mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-18 15:41:11 +08:00
feat: 加大判别器参数
This commit is contained in:
parent
23e46263d8
commit
724b6612d3
3
.gitignore
vendored
3
.gitignore
vendored
@ -5,4 +5,5 @@ ginka-dataset.json
|
||||
ginka-eval.json
|
||||
minamo-dataset.json
|
||||
minamo-eval.json
|
||||
datasets
|
||||
datasets
|
||||
*.log
|
||||
5
cycle2.sh
Normal file
5
cycle2.sh
Normal file
@ -0,0 +1,5 @@
|
||||
for i in {$1...$2}
|
||||
do
|
||||
sh gan.sh "$i"
|
||||
echo "第 $i 次循环完成"
|
||||
done
|
||||
4
gan.sh
4
gan.sh
@ -8,8 +8,8 @@ python3 -m ginka.validate
|
||||
mv "minamo-dataset.json" "datasets/minamo-dataset-$1.json"
|
||||
mv "minamo-eval.json" "datasets/minamo-eval-$1.json"
|
||||
cd data
|
||||
pnpm minamo "../minamo-dataset.json" "../result/ginka_val.json" "../../Apeiria/project" assigned:100:2
|
||||
pnpm minamo "../minamo-eval.json" "../result/ginka_val.json" "../../Apeiria-eval/project" assigned:100:2
|
||||
pnpm minamo "../minamo-dataset.json" "../result/ginka_val.json" "../../Apeiria/project" assigned:100:40
|
||||
pnpm minamo "../minamo-eval.json" "../result/ginka_val.json" "../../Apeiria-eval/project" assigned:100:10
|
||||
pnpm review "../minamo-dataset.json" "../datasets/minamo-dataset-merged.json"
|
||||
pnpm review "../minamo-eval.json" "../datasets/minamo-eval-merged.json"
|
||||
pnpm merge "../datasets/minamo-dataset-merged.json" "../datasets/minamo-dataset-merged.json" "../datasets/minamo-dataset-$1.json"
|
||||
|
||||
@ -5,7 +5,7 @@ from .unet import GinkaUNet
|
||||
from .sample import MapDownSample
|
||||
|
||||
class GinkaModel(nn.Module):
|
||||
def __init__(self, feat_dim=256, base_ch=64, num_classes=32):
|
||||
def __init__(self, feat_dim=1024, base_ch=64, num_classes=32):
|
||||
"""Ginka Model 模型定义部分
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
@ -2,9 +2,9 @@ import json
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import Dataset
|
||||
from shared.graph import convert_soft_map_to_graph
|
||||
from shared.graph import differentiable_convert_to_data
|
||||
|
||||
def random_smooth_onehot(onehot_map, min_main=0.65, max_main=1.0, epsilon=0.35):
|
||||
def random_smooth_onehot(onehot_map, min_main=0.75, max_main=1.0, epsilon=0.25):
|
||||
"""
|
||||
生成随机平滑的 one-hot 编码,使主类别概率不再固定,而是随机波动
|
||||
"""
|
||||
@ -46,8 +46,8 @@ class MinamoDataset(Dataset):
|
||||
map1_probs = random_smooth_onehot(map1_probs)
|
||||
map2_probs = random_smooth_onehot(map2_probs)
|
||||
|
||||
graph1 = convert_soft_map_to_graph(map1_probs)
|
||||
graph2 = convert_soft_map_to_graph(map2_probs)
|
||||
graph1 = differentiable_convert_to_data(map1_probs)
|
||||
graph2 = differentiable_convert_to_data(map2_probs)
|
||||
|
||||
return (
|
||||
map1_probs,
|
||||
|
||||
@ -7,7 +7,7 @@ from torch_geometric.data import Data
|
||||
|
||||
class MinamoTopoModel(nn.Module):
|
||||
def __init__(
|
||||
self, tile_types=32, emb_dim=64, hidden_dim=64, out_dim=512, mlp_dim=128
|
||||
self, tile_types=32, emb_dim=128, hidden_dim=128, out_dim=512, mlp_dim=512
|
||||
):
|
||||
super().__init__()
|
||||
# 传入 softmax 概率值,直接映射
|
||||
@ -15,16 +15,10 @@ class MinamoTopoModel(nn.Module):
|
||||
# 图卷积层
|
||||
self.conv1 = GATConv(emb_dim, hidden_dim*2, heads=8, dropout=0.2)
|
||||
self.conv2 = GATConv(hidden_dim*16, hidden_dim*4, heads=4)
|
||||
self.conv_ins2 = GATConv(hidden_dim*16, hidden_dim*4, heads=4, dropout=0.3)
|
||||
self.conv_ins2 = GATConv(hidden_dim*16, hidden_dim*4, heads=4)
|
||||
self.conv_ins1 = GATConv(hidden_dim*16, hidden_dim*8, heads=2)
|
||||
self.conv3 = GATConv(hidden_dim*16, out_dim, concat=False)
|
||||
|
||||
self.conv1.lin = spectral_norm(self.conv1.lin)
|
||||
self.conv2.lin = spectral_norm(self.conv2.lin)
|
||||
self.conv_ins2.lin = spectral_norm(self.conv_ins2.lin)
|
||||
self.conv_ins1.lin = spectral_norm(self.conv_ins1.lin)
|
||||
self.conv3.lin = spectral_norm(self.conv3.lin)
|
||||
|
||||
# 正则化
|
||||
self.norm1 = nn.LayerNorm(hidden_dim*16)
|
||||
self.norm2 = nn.LayerNorm(hidden_dim*16)
|
||||
|
||||
@ -5,33 +5,35 @@ from torch.nn.utils import spectral_norm
|
||||
from shared.attention import CBAM
|
||||
|
||||
class MinamoVisionModel(nn.Module):
|
||||
def __init__(self, tile_types=32, conv_ch=32, out_dim=128):
|
||||
def __init__(self, tile_types=32, conv_ch=64, out_dim=512):
|
||||
super().__init__()
|
||||
# 输入 softmax 概率值
|
||||
self.input_conv = nn.Conv2d(tile_types, conv_ch, 3, padding=1)
|
||||
|
||||
# 卷积部分
|
||||
self.vision_conv = nn.Sequential(
|
||||
spectral_norm(nn.Conv2d(conv_ch, conv_ch*2, 3, padding=1)),
|
||||
nn.Conv2d(conv_ch, conv_ch*2, 3, padding=1),
|
||||
nn.BatchNorm2d(conv_ch*2),
|
||||
CBAM(conv_ch*2),
|
||||
nn.GELU(),
|
||||
nn.MaxPool2d(2),
|
||||
nn.Dropout2d(0.4),
|
||||
|
||||
spectral_norm(nn.Conv2d(conv_ch*2, conv_ch*4, 3, padding=1)),
|
||||
nn.Conv2d(conv_ch*2, conv_ch*4, 3, padding=1),
|
||||
nn.BatchNorm2d(conv_ch*4),
|
||||
CBAM(conv_ch*4),
|
||||
nn.GELU(),
|
||||
nn.MaxPool2d(2),
|
||||
nn.Dropout2d(0.4),
|
||||
|
||||
spectral_norm(nn.Conv2d(conv_ch*4, conv_ch*8, 3, padding=1)),
|
||||
nn.Conv2d(conv_ch*4, conv_ch*8, 3, padding=1),
|
||||
nn.BatchNorm2d(conv_ch*8),
|
||||
CBAM(conv_ch*8),
|
||||
nn.GELU(),
|
||||
# nn.MaxPool2d(2),
|
||||
# nn.Dropout2d(0.4),
|
||||
|
||||
spectral_norm(nn.Conv2d(conv_ch*8, conv_ch*8, 3, padding=1)),
|
||||
nn.Conv2d(conv_ch*8, conv_ch*8, 3, padding=1),
|
||||
nn.BatchNorm2d(conv_ch*8),
|
||||
CBAM(conv_ch*8),
|
||||
nn.GELU(),
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
@ -13,6 +14,7 @@ from shared.args import parse_arguments
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
os.makedirs("result", exist_ok=True)
|
||||
os.makedirs("result/minamo_checkpoint", exist_ok=True)
|
||||
disable_tqdm = not sys.stdout.isatty() # 如果 stdout 被重定向,则禁用 tqdm
|
||||
|
||||
def collate_fn(batch):
|
||||
"""动态处理不同尺寸地图的批处理"""
|
||||
@ -56,7 +58,7 @@ def train():
|
||||
|
||||
# 设定优化器与调度器
|
||||
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
|
||||
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=3, T_mult=2, eta_min=1e-6)
|
||||
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
|
||||
criterion = MinamoLoss()
|
||||
|
||||
if args.resume:
|
||||
@ -71,7 +73,7 @@ def train():
|
||||
# param.requires_grad = False
|
||||
|
||||
# 开始训练
|
||||
for epoch in tqdm(range(args.epochs)):
|
||||
for epoch in tqdm(range(args.epochs), disable=disable_tqdm):
|
||||
model.train()
|
||||
total_loss = 0
|
||||
|
||||
@ -79,7 +81,7 @@ def train():
|
||||
# for name, param in model.named_parameters():
|
||||
# param.requires_grad = True
|
||||
|
||||
for batch in tqdm(dataloader, leave=False):
|
||||
for batch in tqdm(dataloader, leave=False, disable=disable_tqdm):
|
||||
# 数据迁移到设备
|
||||
map1, map2, vision_simi, topo_simi, graph1, graph2 = batch
|
||||
map1 = map1.to(device) # 转为 [B, C, H, W]
|
||||
@ -108,7 +110,7 @@ def train():
|
||||
total_loss += loss.item()
|
||||
|
||||
ave_loss = total_loss / len(dataloader)
|
||||
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {ave_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}")
|
||||
print(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {ave_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}")
|
||||
|
||||
# total_norm = 0
|
||||
# for p in model.parameters():
|
||||
@ -130,7 +132,7 @@ def train():
|
||||
model.eval()
|
||||
val_loss = 0
|
||||
with torch.no_grad():
|
||||
for val_batch in val_loader:
|
||||
for val_batch in tqdm(val_loader, leave=False, disable=disable_tqdm):
|
||||
map1_val, map2_val, vision_simi_val, topo_simi_val, graph1, graph2 = val_batch
|
||||
map1_val = map1_val.to(device)
|
||||
map2_val = map2_val.to(device)
|
||||
@ -150,7 +152,7 @@ def train():
|
||||
val_loss += loss_val.item()
|
||||
|
||||
avg_val_loss = val_loss / len(val_loader)
|
||||
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}")
|
||||
print(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}")
|
||||
torch.save({
|
||||
"model_state": model.state_dict(),
|
||||
"optimizer_state": optimizer.state_dict(),
|
||||
|
||||
@ -1,6 +1,56 @@
|
||||
import torch
|
||||
from torch_geometric.data import Data, Batch
|
||||
|
||||
def differentiable_convert_to_data(map_probs: torch.Tensor) -> Data:
|
||||
"""
|
||||
可导的图结构转换(返回PyG Data对象)
|
||||
map_probs: [C, H, W]
|
||||
返回:
|
||||
Data(x=[N,C], edge_index=[2,E], edge_attr=[E,C])
|
||||
"""
|
||||
C, H, W = map_probs.shape
|
||||
device = map_probs.device
|
||||
N = H * W
|
||||
|
||||
# 1. 节点特征(保留所有节点)
|
||||
node_features = map_probs.view(C, -1).T # [N, C]
|
||||
|
||||
# 2. 构建所有可能的边连接(预计算)
|
||||
# 生成坐标网格
|
||||
rows, cols = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
|
||||
node_indices = rows * W + cols
|
||||
|
||||
# 水平连接(右邻居)
|
||||
right_src = node_indices[:, :-1].flatten()
|
||||
right_dst = node_indices[:, 1:].flatten()
|
||||
|
||||
# 垂直连接(下邻居)
|
||||
down_src = node_indices[:-1, :].flatten()
|
||||
down_dst = node_indices[1:, :].flatten()
|
||||
|
||||
# 合并边列表
|
||||
edge_src = torch.cat([right_src, down_src]).to(device)
|
||||
edge_dst = torch.cat([right_dst, down_dst]).to(device)
|
||||
edge_index = torch.stack([edge_src, edge_dst]) # [2, E]
|
||||
|
||||
# 3. 计算可导的边权重(排除墙类型)
|
||||
wall_class_idx = 1 # 假设类型1是墙
|
||||
src_probs = 1.0 - map_probs[wall_class_idx].flatten()[edge_src] # [E]
|
||||
dst_probs = 1.0 - map_probs[wall_class_idx].flatten()[edge_dst] # [E]
|
||||
edge_mask = (src_probs * dst_probs).unsqueeze(1) # [E, 1]
|
||||
|
||||
# 4. 边特征计算(保持可导)
|
||||
src_feat = map_probs[:, edge_src//W, edge_src%W].T # [E, C]
|
||||
dst_feat = map_probs[:, edge_dst//W, edge_dst%W].T # [E, C]
|
||||
edge_attr = (src_feat + dst_feat) / 2 * edge_mask # [E, C]
|
||||
|
||||
return Data(
|
||||
x=node_features,
|
||||
edge_index=edge_index,
|
||||
edge_attr=edge_attr,
|
||||
num_nodes=N
|
||||
)
|
||||
|
||||
def convert_soft_map_to_graph(map_probs: torch.Tensor):
|
||||
"""
|
||||
直接使用 Softmax 概率构建 soft 图结构
|
||||
@ -40,7 +90,7 @@ def batch_convert_soft_map_to_graph(batch_map_probs):
|
||||
batch_graphs = []
|
||||
|
||||
for i in range(B):
|
||||
graph = convert_soft_map_to_graph(batch_map_probs[i]) # 处理单个样本
|
||||
graph = differentiable_convert_to_data(batch_map_probs[i]) # 处理单个样本
|
||||
batch_graphs.append(graph)
|
||||
|
||||
# 合并所有图为批量 Batch
|
||||
|
||||
Loading…
Reference in New Issue
Block a user