feat: 加大判别器参数

This commit is contained in:
unanmed 2025-03-24 13:48:30 +08:00
parent 23e46263d8
commit 724b6612d3
9 changed files with 82 additions and 28 deletions

3
.gitignore vendored
View File

@ -5,4 +5,5 @@ ginka-dataset.json
ginka-eval.json
minamo-dataset.json
minamo-eval.json
datasets
datasets
*.log

5
cycle2.sh Normal file
View File

@ -0,0 +1,5 @@
# Run gan.sh once for every iteration index from $1 to $2 (inclusive).
# Usage: sh cycle2.sh <start> <end>
#
# A brace sequence such as {$1..$2} cannot be used here: brace expansion is
# performed before parameter expansion, so the variables would never be
# expanded into a range. An arithmetic counter works in any POSIX shell.
if [ "$#" -ne 2 ]; then
    echo "usage: $0 <start> <end>" >&2
    exit 1
fi

i=$1
while [ "$i" -le "$2" ]
do
    sh gan.sh "$i"
    echo "$i 次循环完成"
    i=$((i + 1))
done

4
gan.sh
View File

@ -8,8 +8,8 @@ python3 -m ginka.validate
mv "minamo-dataset.json" "datasets/minamo-dataset-$1.json"
mv "minamo-eval.json" "datasets/minamo-eval-$1.json"
cd data
pnpm minamo "../minamo-dataset.json" "../result/ginka_val.json" "../../Apeiria/project" assigned:100:2
pnpm minamo "../minamo-eval.json" "../result/ginka_val.json" "../../Apeiria-eval/project" assigned:100:2
pnpm minamo "../minamo-dataset.json" "../result/ginka_val.json" "../../Apeiria/project" assigned:100:40
pnpm minamo "../minamo-eval.json" "../result/ginka_val.json" "../../Apeiria-eval/project" assigned:100:10
pnpm review "../minamo-dataset.json" "../datasets/minamo-dataset-merged.json"
pnpm review "../minamo-eval.json" "../datasets/minamo-eval-merged.json"
pnpm merge "../datasets/minamo-dataset-merged.json" "../datasets/minamo-dataset-merged.json" "../datasets/minamo-dataset-$1.json"

View File

@ -5,7 +5,7 @@ from .unet import GinkaUNet
from .sample import MapDownSample
class GinkaModel(nn.Module):
def __init__(self, feat_dim=256, base_ch=64, num_classes=32):
def __init__(self, feat_dim=1024, base_ch=64, num_classes=32):
"""Ginka Model 模型定义部分
"""
super().__init__()

View File

@ -2,9 +2,9 @@ import json
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from shared.graph import convert_soft_map_to_graph
from shared.graph import differentiable_convert_to_data
def random_smooth_onehot(onehot_map, min_main=0.65, max_main=1.0, epsilon=0.35):
def random_smooth_onehot(onehot_map, min_main=0.75, max_main=1.0, epsilon=0.25):
"""
生成随机平滑的 one-hot 编码使主类别概率不再固定而是随机波动
"""
@ -46,8 +46,8 @@ class MinamoDataset(Dataset):
map1_probs = random_smooth_onehot(map1_probs)
map2_probs = random_smooth_onehot(map2_probs)
graph1 = convert_soft_map_to_graph(map1_probs)
graph2 = convert_soft_map_to_graph(map2_probs)
graph1 = differentiable_convert_to_data(map1_probs)
graph2 = differentiable_convert_to_data(map2_probs)
return (
map1_probs,

View File

@ -7,7 +7,7 @@ from torch_geometric.data import Data
class MinamoTopoModel(nn.Module):
def __init__(
self, tile_types=32, emb_dim=64, hidden_dim=64, out_dim=512, mlp_dim=128
self, tile_types=32, emb_dim=128, hidden_dim=128, out_dim=512, mlp_dim=512
):
super().__init__()
# 传入 softmax 概率值,直接映射
@ -15,16 +15,10 @@ class MinamoTopoModel(nn.Module):
# 图卷积层
self.conv1 = GATConv(emb_dim, hidden_dim*2, heads=8, dropout=0.2)
self.conv2 = GATConv(hidden_dim*16, hidden_dim*4, heads=4)
self.conv_ins2 = GATConv(hidden_dim*16, hidden_dim*4, heads=4, dropout=0.3)
self.conv_ins2 = GATConv(hidden_dim*16, hidden_dim*4, heads=4)
self.conv_ins1 = GATConv(hidden_dim*16, hidden_dim*8, heads=2)
self.conv3 = GATConv(hidden_dim*16, out_dim, concat=False)
self.conv1.lin = spectral_norm(self.conv1.lin)
self.conv2.lin = spectral_norm(self.conv2.lin)
self.conv_ins2.lin = spectral_norm(self.conv_ins2.lin)
self.conv_ins1.lin = spectral_norm(self.conv_ins1.lin)
self.conv3.lin = spectral_norm(self.conv3.lin)
# 正则化
self.norm1 = nn.LayerNorm(hidden_dim*16)
self.norm2 = nn.LayerNorm(hidden_dim*16)

View File

@ -5,33 +5,35 @@ from torch.nn.utils import spectral_norm
from shared.attention import CBAM
class MinamoVisionModel(nn.Module):
def __init__(self, tile_types=32, conv_ch=32, out_dim=128):
def __init__(self, tile_types=32, conv_ch=64, out_dim=512):
super().__init__()
# 输入 softmax 概率值
self.input_conv = nn.Conv2d(tile_types, conv_ch, 3, padding=1)
# 卷积部分
self.vision_conv = nn.Sequential(
spectral_norm(nn.Conv2d(conv_ch, conv_ch*2, 3, padding=1)),
nn.Conv2d(conv_ch, conv_ch*2, 3, padding=1),
nn.BatchNorm2d(conv_ch*2),
CBAM(conv_ch*2),
nn.GELU(),
nn.MaxPool2d(2),
nn.Dropout2d(0.4),
spectral_norm(nn.Conv2d(conv_ch*2, conv_ch*4, 3, padding=1)),
nn.Conv2d(conv_ch*2, conv_ch*4, 3, padding=1),
nn.BatchNorm2d(conv_ch*4),
CBAM(conv_ch*4),
nn.GELU(),
nn.MaxPool2d(2),
nn.Dropout2d(0.4),
spectral_norm(nn.Conv2d(conv_ch*4, conv_ch*8, 3, padding=1)),
nn.Conv2d(conv_ch*4, conv_ch*8, 3, padding=1),
nn.BatchNorm2d(conv_ch*8),
CBAM(conv_ch*8),
nn.GELU(),
# nn.MaxPool2d(2),
# nn.Dropout2d(0.4),
spectral_norm(nn.Conv2d(conv_ch*8, conv_ch*8, 3, padding=1)),
nn.Conv2d(conv_ch*8, conv_ch*8, 3, padding=1),
nn.BatchNorm2d(conv_ch*8),
CBAM(conv_ch*8),
nn.GELU(),

View File

@ -1,4 +1,5 @@
import os
import sys
from datetime import datetime
import torch
import torch.optim as optim
@ -13,6 +14,7 @@ from shared.args import parse_arguments
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs("result", exist_ok=True)
os.makedirs("result/minamo_checkpoint", exist_ok=True)
disable_tqdm = not sys.stdout.isatty() # 如果 stdout 被重定向,则禁用 tqdm
def collate_fn(batch):
"""动态处理不同尺寸地图的批处理"""
@ -56,7 +58,7 @@ def train():
# 设定优化器与调度器
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=3, T_mult=2, eta_min=1e-6)
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
criterion = MinamoLoss()
if args.resume:
@ -71,7 +73,7 @@ def train():
# param.requires_grad = False
# 开始训练
for epoch in tqdm(range(args.epochs)):
for epoch in tqdm(range(args.epochs), disable=disable_tqdm):
model.train()
total_loss = 0
@ -79,7 +81,7 @@ def train():
# for name, param in model.named_parameters():
# param.requires_grad = True
for batch in tqdm(dataloader, leave=False):
for batch in tqdm(dataloader, leave=False, disable=disable_tqdm):
# 数据迁移到设备
map1, map2, vision_simi, topo_simi, graph1, graph2 = batch
map1 = map1.to(device) # 转为 [B, C, H, W]
@ -108,7 +110,7 @@ def train():
total_loss += loss.item()
ave_loss = total_loss / len(dataloader)
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {ave_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}")
print(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {ave_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}")
# total_norm = 0
# for p in model.parameters():
@ -130,7 +132,7 @@ def train():
model.eval()
val_loss = 0
with torch.no_grad():
for val_batch in val_loader:
for val_batch in tqdm(val_loader, leave=False, disable=disable_tqdm):
map1_val, map2_val, vision_simi_val, topo_simi_val, graph1, graph2 = val_batch
map1_val = map1_val.to(device)
map2_val = map2_val.to(device)
@ -150,7 +152,7 @@ def train():
val_loss += loss_val.item()
avg_val_loss = val_loss / len(val_loader)
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}")
print(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}")
torch.save({
"model_state": model.state_dict(),
"optimizer_state": optimizer.state_dict(),

View File

@ -1,6 +1,56 @@
import torch
from torch_geometric.data import Data, Batch
def differentiable_convert_to_data(
    map_probs: torch.Tensor, wall_class_idx: int = 1
) -> Data:
    """Convert a per-cell class-probability map into a PyG ``Data`` grid graph.

    Every cell of the ``H x W`` grid becomes one node, and each node is linked
    to its right and lower neighbour (one directed edge per adjacent pair).
    Node and edge features are computed from ``map_probs`` with indexing and
    arithmetic only, so gradients can flow back into ``map_probs``.

    Args:
        map_probs: Tensor of shape ``[C, H, W]`` holding per-class values for
            each cell. It may be non-contiguous.
        wall_class_idx: Channel whose value is used to down-weight edges.
            Defaults to 1 (the original code assumes class 1 is "wall";
            confirm against the tile encoding).

    Returns:
        ``Data(x=[N, C], edge_index=[2, E], edge_attr=[E, C], num_nodes=N)``
        with ``N = H * W`` and ``E = H * (W - 1) + (H - 1) * W``.
    """
    C, H, W = map_probs.shape
    device = map_probs.device
    N = H * W

    # 1. Node features, one row per cell in row-major order.
    #    reshape (rather than view) also accepts non-contiguous inputs.
    node_features = map_probs.reshape(C, N).T  # [N, C]

    # 2. Enumerate neighbour pairs. Node id of cell (r, c) is r * W + c; the
    #    index grid is created directly on the target device.
    node_indices = torch.arange(N, device=device).view(H, W)

    # Horizontal links: each cell -> its right neighbour.
    right_src = node_indices[:, :-1].flatten()
    right_dst = node_indices[:, 1:].flatten()

    # Vertical links: each cell -> its lower neighbour.
    down_src = node_indices[:-1, :].flatten()
    down_dst = node_indices[1:, :].flatten()

    edge_src = torch.cat([right_src, down_src])
    edge_dst = torch.cat([right_dst, down_dst])
    edge_index = torch.stack([edge_src, edge_dst])  # [2, E]

    # 3. Soft edge weight: product of (1 - wall value) at both endpoints, so
    #    an edge fades out as either endpoint approaches the wall class.
    non_wall = 1.0 - node_features[:, wall_class_idx]  # [N]
    edge_mask = (non_wall[edge_src] * non_wall[edge_dst]).unsqueeze(1)  # [E, 1]

    # 4. Edge features: mean of the two endpoint feature vectors, scaled by
    #    the soft edge weight.
    src_feat = node_features[edge_src]  # [E, C]
    dst_feat = node_features[edge_dst]  # [E, C]
    edge_attr = (src_feat + dst_feat) / 2 * edge_mask  # [E, C]

    return Data(
        x=node_features,
        edge_index=edge_index,
        edge_attr=edge_attr,
        num_nodes=N,
    )
def convert_soft_map_to_graph(map_probs: torch.Tensor):
"""
直接使用 Softmax 概率构建 soft 图结构
@ -40,7 +90,7 @@ def batch_convert_soft_map_to_graph(batch_map_probs):
batch_graphs = []
for i in range(B):
graph = convert_soft_map_to_graph(batch_map_probs[i]) # 处理单个样本
graph = differentiable_convert_to_data(batch_map_probs[i]) # 处理单个样本
batch_graphs.append(graph)
# 合并所有图为批量 Batch