diff --git a/.gitignore b/.gitignore index 02bff83..bc2808a 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,5 @@ ginka-dataset.json ginka-eval.json minamo-dataset.json minamo-eval.json -datasets \ No newline at end of file +datasets +*.log \ No newline at end of file diff --git a/cycle2.sh b/cycle2.sh new file mode 100644 index 0000000..98187c8 --- /dev/null +++ b/cycle2.sh @@ -0,0 +1,5 @@ +for i in {$1...$2} +do + sh gan.sh "$i" + echo "第 $i 次循环完成" +done diff --git a/gan.sh b/gan.sh index 5d26685..4616733 100644 --- a/gan.sh +++ b/gan.sh @@ -8,8 +8,8 @@ python3 -m ginka.validate mv "minamo-dataset.json" "datasets/minamo-dataset-$1.json" mv "minamo-eval.json" "datasets/minamo-eval-$1.json" cd data -pnpm minamo "../minamo-dataset.json" "../result/ginka_val.json" "../../Apeiria/project" assigned:100:2 -pnpm minamo "../minamo-eval.json" "../result/ginka_val.json" "../../Apeiria-eval/project" assigned:100:2 +pnpm minamo "../minamo-dataset.json" "../result/ginka_val.json" "../../Apeiria/project" assigned:100:40 +pnpm minamo "../minamo-eval.json" "../result/ginka_val.json" "../../Apeiria-eval/project" assigned:100:10 pnpm review "../minamo-dataset.json" "../datasets/minamo-dataset-merged.json" pnpm review "../minamo-eval.json" "../datasets/minamo-eval-merged.json" pnpm merge "../datasets/minamo-dataset-merged.json" "../datasets/minamo-dataset-merged.json" "../datasets/minamo-dataset-$1.json" diff --git a/ginka/model/model.py b/ginka/model/model.py index 06356d7..eddc047 100644 --- a/ginka/model/model.py +++ b/ginka/model/model.py @@ -5,7 +5,7 @@ from .unet import GinkaUNet from .sample import MapDownSample class GinkaModel(nn.Module): - def __init__(self, feat_dim=256, base_ch=64, num_classes=32): + def __init__(self, feat_dim=1024, base_ch=64, num_classes=32): """Ginka Model 模型定义部分 """ super().__init__() diff --git a/minamo/dataset.py b/minamo/dataset.py index 177fe8d..4dda47c 100644 --- a/minamo/dataset.py +++ b/minamo/dataset.py @@ -2,9 +2,9 @@ import json import torch import torch.nn.functional as F from torch.utils.data import Dataset -from shared.graph import convert_soft_map_to_graph +from shared.graph import differentiable_convert_to_data -def random_smooth_onehot(onehot_map, min_main=0.65, max_main=1.0, epsilon=0.35): +def random_smooth_onehot(onehot_map, min_main=0.75, max_main=1.0, epsilon=0.25): """ 生成随机平滑的 one-hot 编码,使主类别概率不再固定,而是随机波动 """ @@ -46,8 +46,8 @@ class MinamoDataset(Dataset): map1_probs = random_smooth_onehot(map1_probs) map2_probs = random_smooth_onehot(map2_probs) - graph1 = convert_soft_map_to_graph(map1_probs) - graph2 = convert_soft_map_to_graph(map2_probs) + graph1 = differentiable_convert_to_data(map1_probs) + graph2 = differentiable_convert_to_data(map2_probs) return ( map1_probs, diff --git a/minamo/model/topo.py b/minamo/model/topo.py index 931dbc3..57df620 100644 --- a/minamo/model/topo.py +++ b/minamo/model/topo.py @@ -7,7 +7,7 @@ from torch_geometric.data import Data class MinamoTopoModel(nn.Module): def __init__( - self, tile_types=32, emb_dim=64, hidden_dim=64, out_dim=512, mlp_dim=128 + self, tile_types=32, emb_dim=128, hidden_dim=128, out_dim=512, mlp_dim=512 ): super().__init__() # 传入 softmax 概率值,直接映射 @@ -15,16 +15,10 @@ class MinamoTopoModel(nn.Module): # 图卷积层 self.conv1 = GATConv(emb_dim, hidden_dim*2, heads=8, dropout=0.2) self.conv2 = GATConv(hidden_dim*16, hidden_dim*4, heads=4) - self.conv_ins2 = GATConv(hidden_dim*16, hidden_dim*4, heads=4, dropout=0.3) + self.conv_ins2 = GATConv(hidden_dim*16, hidden_dim*4, heads=4) self.conv_ins1 = GATConv(hidden_dim*16, hidden_dim*8, heads=2) self.conv3 = GATConv(hidden_dim*16, out_dim, concat=False) - self.conv1.lin = spectral_norm(self.conv1.lin) - self.conv2.lin = spectral_norm(self.conv2.lin) - self.conv_ins2.lin = spectral_norm(self.conv_ins2.lin) - self.conv_ins1.lin = spectral_norm(self.conv_ins1.lin) - self.conv3.lin = spectral_norm(self.conv3.lin) - # 正则化 self.norm1 = nn.LayerNorm(hidden_dim*16) self.norm2 = nn.LayerNorm(hidden_dim*16) diff --git a/minamo/model/vision.py b/minamo/model/vision.py index 2bd8572..20953b2 100644 --- a/minamo/model/vision.py +++ b/minamo/model/vision.py @@ -5,33 +5,35 @@ from torch.nn.utils import spectral_norm from shared.attention import CBAM class MinamoVisionModel(nn.Module): - def __init__(self, tile_types=32, conv_ch=32, out_dim=128): + def __init__(self, tile_types=32, conv_ch=64, out_dim=512): super().__init__() # 输入 softmax 概率值 self.input_conv = nn.Conv2d(tile_types, conv_ch, 3, padding=1) # 卷积部分 self.vision_conv = nn.Sequential( - spectral_norm(nn.Conv2d(conv_ch, conv_ch*2, 3, padding=1)), + nn.Conv2d(conv_ch, conv_ch*2, 3, padding=1), nn.BatchNorm2d(conv_ch*2), CBAM(conv_ch*2), nn.GELU(), nn.MaxPool2d(2), nn.Dropout2d(0.4), - spectral_norm(nn.Conv2d(conv_ch*2, conv_ch*4, 3, padding=1)), + nn.Conv2d(conv_ch*2, conv_ch*4, 3, padding=1), nn.BatchNorm2d(conv_ch*4), CBAM(conv_ch*4), nn.GELU(), nn.MaxPool2d(2), nn.Dropout2d(0.4), - spectral_norm(nn.Conv2d(conv_ch*4, conv_ch*8, 3, padding=1)), + nn.Conv2d(conv_ch*4, conv_ch*8, 3, padding=1), nn.BatchNorm2d(conv_ch*8), CBAM(conv_ch*8), nn.GELU(), + # nn.MaxPool2d(2), + # nn.Dropout2d(0.4), - spectral_norm(nn.Conv2d(conv_ch*8, conv_ch*8, 3, padding=1)), + nn.Conv2d(conv_ch*8, conv_ch*8, 3, padding=1), nn.BatchNorm2d(conv_ch*8), CBAM(conv_ch*8), nn.GELU(), diff --git a/minamo/train.py b/minamo/train.py index b2aed3e..e4416ad 100644 --- a/minamo/train.py +++ b/minamo/train.py @@ -1,4 +1,5 @@ import os +import sys from datetime import datetime import torch import torch.optim as optim @@ -13,6 +14,7 @@ from shared.args import parse_arguments device = torch.device("cuda" if torch.cuda.is_available() else "cpu") os.makedirs("result", exist_ok=True) os.makedirs("result/minamo_checkpoint", exist_ok=True) +disable_tqdm = not sys.stdout.isatty() # 如果 stdout 被重定向,则禁用 tqdm def collate_fn(batch): """动态处理不同尺寸地图的批处理""" @@ -56,7 +58,7 @@ def train(): # 设定优化器与调度器 optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2) - scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=3, T_mult=2, eta_min=1e-6) + scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6) criterion = MinamoLoss() if args.resume: @@ -71,7 +73,7 @@ def train(): # param.requires_grad = False # 开始训练 - for epoch in tqdm(range(args.epochs)): + for epoch in tqdm(range(args.epochs), disable=disable_tqdm): model.train() total_loss = 0 @@ -79,7 +81,7 @@ def train(): # for name, param in model.named_parameters(): # param.requires_grad = True - for batch in tqdm(dataloader, leave=False): + for batch in tqdm(dataloader, leave=False, disable=disable_tqdm): # 数据迁移到设备 map1, map2, vision_simi, topo_simi, graph1, graph2 = batch map1 = map1.to(device) # 转为 [B, C, H, W] @@ -108,7 +110,7 @@ def train(): total_loss += loss.item() ave_loss = total_loss / len(dataloader) - tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {ave_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}") + print(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {ave_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}") # total_norm = 0 # for p in model.parameters(): @@ -130,7 +132,7 @@ def train(): model.eval() val_loss = 0 with torch.no_grad(): - for val_batch in val_loader: + for val_batch in tqdm(val_loader, leave=False, disable=disable_tqdm): map1_val, map2_val, vision_simi_val, topo_simi_val, graph1, graph2 = val_batch map1_val = map1_val.to(device) map2_val = map2_val.to(device) @@ -150,7 +152,7 @@ def train(): val_loss += loss_val.item() avg_val_loss = val_loss / len(val_loader) - tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}") + print(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}") torch.save({ "model_state": model.state_dict(), "optimizer_state": optimizer.state_dict(), diff --git a/shared/graph.py b/shared/graph.py index 17c15c6..48dcb75 100644 --- a/shared/graph.py +++ b/shared/graph.py @@ -1,6 +1,56 @@ import torch from torch_geometric.data import Data, Batch +def differentiable_convert_to_data(map_probs: torch.Tensor) -> Data: + """ + 可导的图结构转换(返回PyG Data对象) + map_probs: [C, H, W] + 返回: + Data(x=[N,C], edge_index=[2,E], edge_attr=[E,C]) + """ + C, H, W = map_probs.shape + device = map_probs.device + N = H * W + + # 1. 节点特征(保留所有节点) + node_features = map_probs.view(C, -1).T # [N, C] + + # 2. 构建所有可能的边连接(预计算) + # 生成坐标网格 + rows, cols = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij') + node_indices = rows * W + cols + + # 水平连接(右邻居) + right_src = node_indices[:, :-1].flatten() + right_dst = node_indices[:, 1:].flatten() + + # 垂直连接(下邻居) + down_src = node_indices[:-1, :].flatten() + down_dst = node_indices[1:, :].flatten() + + # 合并边列表 + edge_src = torch.cat([right_src, down_src]).to(device) + edge_dst = torch.cat([right_dst, down_dst]).to(device) + edge_index = torch.stack([edge_src, edge_dst]) # [2, E] + + # 3. 计算可导的边权重(排除墙类型) + wall_class_idx = 1 # 假设类型1是墙 + src_probs = 1.0 - map_probs[wall_class_idx].flatten()[edge_src] # [E] + dst_probs = 1.0 - map_probs[wall_class_idx].flatten()[edge_dst] # [E] + edge_mask = (src_probs * dst_probs).unsqueeze(1) # [E, 1] + + # 4. 边特征计算(保持可导) + src_feat = map_probs[:, edge_src//W, edge_src%W].T # [E, C] + dst_feat = map_probs[:, edge_dst//W, edge_dst%W].T # [E, C] + edge_attr = (src_feat + dst_feat) / 2 * edge_mask # [E, C] + + return Data( + x=node_features, + edge_index=edge_index, + edge_attr=edge_attr, + num_nodes=N + ) + def convert_soft_map_to_graph(map_probs: torch.Tensor): """ 直接使用 Softmax 概率构建 soft 图结构 @@ -40,7 +90,7 @@ def batch_convert_soft_map_to_graph(batch_map_probs): batch_graphs = [] for i in range(B): - graph = convert_soft_map_to_graph(batch_map_probs[i]) # 处理单个样本 + graph = differentiable_convert_to_data(batch_map_probs[i]) # 处理单个样本 batch_graphs.append(graph) # 合并所有图为批量 Batch