perf: 改进 WGAN 训练

This commit is contained in:
unanmed 2025-04-10 22:42:58 +08:00
parent 268b21e0b7
commit 99f46150be
12 changed files with 711 additions and 127 deletions

View File

@ -7,6 +7,8 @@ from torch_geometric.data import Data
from minamo.model.model import MinamoModel from minamo.model.model import MinamoModel
from shared.graph import batch_convert_soft_map_to_graph from shared.graph import batch_convert_soft_map_to_graph
from shared.constant import VISION_WEIGHT, TOPO_WEIGHT from shared.constant import VISION_WEIGHT, TOPO_WEIGHT
from shared.similarity.topo import overall_similarity, build_topological_graph
from shared.similarity.vision import calculate_visual_similarity
CLASS_NUM = 32 CLASS_NUM = 32
ILLEGAL_MAX_NUM = 12 ILLEGAL_MAX_NUM = 12
@ -286,32 +288,21 @@ def interpolate_graph_features(real_graph, fake_graph, epsilon=0.5):
return Data(x=x_interp, edge_index=edge_index_interp, edge_attr=edge_attr_interp) return Data(x=x_interp, edge_index=edge_index_interp, edge_attr=edge_attr_interp)
def js_divergence(p, q, eps=1e-8):
    """Jensen-Shannon divergence between two probability tensors.

    Args:
        p, q: Tensors of identical shape holding probabilities (the caller is
            expected to have applied softmax already).
        eps: Small constant added before taking the logarithm.

    Returns:
        Scalar tensor ``0.5 * (KL(p || m) + KL(q || m))`` with
        ``m = 0.5 * (p + q)``, reduced with ``'batchmean'`` (summed over every
        element, divided by the batch size) and clamped to at most 1.0.
    """
    m = 0.5 * (p + q)

    # F.kl_div(input, target) evaluates target * (log(target) - input), i.e.
    # KL(target || exp(input)).  To obtain KL(p || m) the *mixture* must be
    # passed as the log-space input and p as the target; the previous argument
    # order computed KL(m || p) instead.
    log_m = torch.log(m + eps)

    kl_pm = F.kl_div(log_m, p, reduction='batchmean', log_target=False)  # KL(p || m)
    kl_qm = F.kl_div(log_m, q, reduction='batchmean', log_target=False)  # KL(q || m)

    return torch.clamp(0.5 * (kl_pm + kl_qm), max=1.0)
class WGANGinkaLoss: class WGANGinkaLoss:
def __init__(self, lambda_gp=50, weight=[0.7, 0.2, 0.1], diversity_lamda=0.2): def __init__(self, lambda_gp=100, weight=[0.8, 0.1, 0.1], diversity_lamda=0.4):
self.lambda_gp = lambda_gp # 梯度惩罚系数 self.lambda_gp = lambda_gp # 梯度惩罚系数
self.weight = weight self.weight = weight
self.diversity_lamda = diversity_lamda self.diversity_lamda = diversity_lamda
@ -369,8 +360,50 @@ class WGANGinkaLoss:
return d_loss, d_loss + self.lambda_gp * grad_loss return d_loss, d_loss + self.lambda_gp * grad_loss
def calculate_similarity_one(self, map1, map2):
    """Score a single pair of maps with the rule-based similarity metrics.

    Returns:
        ``(visual_similarity, topological_similarity)`` as produced by
        ``calculate_visual_similarity`` and ``overall_similarity``.
    """
    # Topological graphs are built first, in the same order as the inputs.
    graph_a, graph_b = build_topological_graph(map1), build_topological_graph(map2)
    visual = calculate_visual_similarity(map1, map2)
    return visual, overall_similarity(graph_a, graph_b)
def discriminator_loss_assist(self, critic, fake_data1, fake_data2):
    """Loss for the auxiliary similarity critic on one pair of map batches.

    The critic's feature vectors for both batches are compared with cosine
    similarity, and the result is regressed (L1) onto the rule-based visual
    and topological similarities of the corresponding arg-max maps.

    Args:
        critic: Callable returning ``(vision_features, topo_features)`` for
            ``(maps, graph)``.
        fake_data1, fake_data2: Soft map batches; arg-max is taken over
            dim 1 (presumably the tile-class channel -- confirm with caller).

    Returns:
        ``L1(vision) * VISION_WEIGHT + L1(topo) * TOPO_WEIGHT``.
    """
    graph_a = batch_convert_soft_map_to_graph(fake_data1)
    graph_b = batch_convert_soft_map_to_graph(fake_data2)

    vis_a, topo_a = critic(fake_data1, graph_a)
    vis_b, topo_b = critic(fake_data2, graph_b)

    # Discretise the soft maps so the rule-based metrics can consume them.
    maps_a = torch.argmax(fake_data1, dim=1).cpu().tolist()
    maps_b = torch.argmax(fake_data2, dim=1).cpu().tolist()

    target_vis, target_topo = [], []
    for idx, map_a in enumerate(maps_a):
        vis_sim, topo_sim = self.calculate_similarity_one(map_a, maps_b[idx])
        target_vis.append(vis_sim)
        target_topo.append(topo_sim)

    target_vis = torch.Tensor(target_vis)
    target_topo = torch.Tensor(target_topo)

    # Targets live on the CPU, so predictions are moved there as well.
    pred_vis = F.cosine_similarity(vis_a, vis_b).cpu()
    pred_topo = F.cosine_similarity(topo_a, topo_b).cpu()

    vis_term = F.l1_loss(pred_vis, target_vis) * VISION_WEIGHT
    topo_term = F.l1_loss(pred_topo, target_topo) * TOPO_WEIGHT
    return vis_term + topo_term
def discriminator_loss_assist2(self, critic, real_data, fake_data1, fake_data2):
    """Average the auxiliary-critic loss over the three possible batch pairs.

    The pairs are (real, fake1), (real, fake2) and (fake1, fake2), evaluated
    in that order; each contributes one third of the result.
    """
    pairs = (
        (real_data, fake_data1),
        (real_data, fake_data2),
        (fake_data1, fake_data2),
    )
    real_vs_1, real_vs_2, fake_vs_fake = [
        self.discriminator_loss_assist(critic, left, right) for left, right in pairs
    ]
    return real_vs_1 / 3.0 + real_vs_2 / 3.0 + fake_vs_fake / 3.0
def generator_loss_one(self, critic, fake, fake_graph):
fake_scores, _, _ = critic(fake, fake_graph) fake_scores, _, _ = critic(fake, fake_graph)
minamo_loss = -torch.mean(fake_scores) minamo_loss = -torch.mean(fake_scores)
class_loss = outer_border_constraint_loss(fake) + inner_constraint_loss(fake) class_loss = outer_border_constraint_loss(fake) + inner_constraint_loss(fake)
@ -384,15 +417,24 @@ class WGANGinkaLoss:
return sum(losses) return sum(losses)
def generator_loss(self, critic, critic_assist, fake1, fake2):
    """Generator loss: the mean of the single-batch losses of two fake batches.

    Args:
        critic: Main critic, forwarded to ``generator_loss_one``.
        critic_assist: Auxiliary similarity critic.  Currently unused; it is
            kept in the signature for the disabled diversity term below.
        fake1, fake2: Two generated soft-map batches.

    Returns:
        ``0.5 * loss(fake1) + 0.5 * loss(fake2)``.
    """
    fake_graph1 = batch_convert_soft_map_to_graph(fake1)
    fake_graph2 = batch_convert_soft_map_to_graph(fake2)

    loss1 = self.generator_loss_one(critic, fake1, fake_graph1)
    loss2 = self.generator_loss_one(critic, fake2, fake_graph2)

    # NOTE: two diversity penalties were tried and are currently disabled:
    #   * diversity_lamda * relu(similarity - 0.4), where similarity is the
    #     VISION_WEIGHT / TOPO_WEIGHT blend of the cosine similarities of the
    #     critic_assist features of fake1 and fake2;
    #   * diversity_lamda * relu(0.7 - L1(fake1, fake2)) on the map interior
    #     ([:, :, 1:-1, 1:-1]).
    # The return statement no longer ends in a dangling line continuation, so
    # editing the comments above cannot change what is returned.
    return loss1 * 0.5 + loss2 * 0.5

View File

@ -37,12 +37,12 @@ if __name__ == "__main__":
print_memory("初始化后") print_memory("初始化后")
# 前向传播 # 前向传播
output, output_softmax = model(feat) output = model(feat)
print_memory("前向传播后") print_memory("前向传播后")
print(f"输入形状: feat={feat.shape}") print(f"输入形状: feat={feat.shape}")
print(f"输出形状: output={output.shape}, softmax={output_softmax.shape}") print(f"输出形状: output={output.shape}")
# print(f"Input parameters: {sum(p.numel() for p in model.input.parameters())}") # print(f"Input parameters: {sum(p.numel() for p in model.input.parameters())}")
# print(f"Feature Encoder parameters: {sum(p.numel() for p in model.feat_enc.parameters())}") # print(f"Feature Encoder parameters: {sum(p.numel() for p in model.feat_enc.parameters())}")
print(f"UNet parameters: {sum(p.numel() for p in model.unet.parameters())}") print(f"UNet parameters: {sum(p.numel() for p in model.unet.parameters())}")

View File

@ -1,7 +1,9 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from shared.constant import FEAT_DIM from torch_geometric.nn import GCNConv
from torch_geometric.utils import grid
from shared.attention import ChannelAttention
class GinkaTransformerEncoder(nn.Module): class GinkaTransformerEncoder(nn.Module):
def __init__(self, in_dim, hidden_dim, out_dim, token_size, ff_dim, num_heads=8, num_layers=6): def __init__(self, in_dim, hidden_dim, out_dim, token_size, ff_dim, num_heads=8, num_layers=6):
@ -33,7 +35,7 @@ class GinkaTransformerEncoder(nn.Module):
return x return x
class ConvBlock(nn.Module): class ConvBlock(nn.Module):
def __init__(self, in_ch, out_ch): def __init__(self, in_ch, out_ch, atte=True):
super().__init__() super().__init__()
self.conv = nn.Sequential( self.conv = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='replicate'), nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='replicate'),
@ -41,12 +43,70 @@ class ConvBlock(nn.Module):
nn.ELU(), nn.ELU(),
nn.Conv2d(out_ch, out_ch, 3, padding=1, padding_mode='replicate'), nn.Conv2d(out_ch, out_ch, 3, padding=1, padding_mode='replicate'),
nn.InstanceNorm2d(out_ch), nn.InstanceNorm2d(out_ch),
nn.ELU(),
) )
if atte:
self.conv.append(ChannelAttention(out_ch))
self.conv.append(nn.ELU())
def forward(self, x): def forward(self, x):
return self.conv(x) return self.conv(x)
class GCNBlock(nn.Module):
    """Two-layer GCN applied to a feature map viewed as a 4-connected grid graph.

    Every spatial position of a ``[B, C, H, W]`` tensor becomes a graph node
    and the edges come from ``torch_geometric.utils.grid(h, w)``.

    Args:
        in_ch: Input channels per node.
        hidden_ch: Channels after the first GCN layer.
        out_ch: Output channels per node.
        w, h: Width and height of the grid the edge index is built for.
    """

    def __init__(self, in_ch, hidden_ch, out_ch, w, h):
        super().__init__()
        self.conv1 = GCNConv(in_ch, hidden_ch)
        self.conv2 = GCNConv(hidden_ch, out_ch)
        self.norm1 = nn.LayerNorm(hidden_ch)
        self.norm2 = nn.LayerNorm(out_ch)
        self.grid_size = (h, w)
        edge_index, _ = grid(h, w)  # [2, E] for a single map
        # Non-persistent buffer: follows the module across devices without
        # being written to (or required from) checkpoints.
        self.register_buffer("single_edge_index", edge_index, persistent=False)

    def forward(self, x):
        """Run the GCN on ``x`` of shape ``[B, C, H, W]``; returns ``[B, out_ch, H, W]``."""
        B, C, H, W = x.shape
        if (H, W) != self.grid_size:
            # The edge index only describes an h x w grid; any other size would
            # silently connect the wrong nodes or index out of range.
            raise ValueError(
                f"GCNBlock was built for a {self.grid_size[0]}x{self.grid_size[1]} grid "
                f"but received a {H}x{W} feature map"
            )

        # Flatten to one node per spatial position: [B * H * W, C]
        x = x.permute(0, 2, 3, 1).reshape(B * H * W, C)

        # .to() is a no-op once the module already lives on x's device.
        edge_index = self._batch_edge_index(B, self.single_edge_index.to(x.device), H * W)

        x = self.conv1(x, edge_index)
        x = F.elu(self.norm1(x))
        x = self.conv2(x, edge_index)
        x = F.elu(self.norm2(x))

        # Back to image layout: [B, out_ch, H, W]
        x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
        return x

    def _batch_edge_index(self, B, edge_index, num_nodes_per_batch):
        """Replicate a single-graph edge index for ``B`` disjoint graphs.

        Copy ``i`` is shifted by ``i * num_nodes_per_batch`` and the copies are
        concatenated along the edge dimension in batch order.  The input
        tensor is not modified.
        """
        offsets = torch.arange(B, device=edge_index.device, dtype=edge_index.dtype)
        offsets = (offsets * num_nodes_per_batch).view(1, B, 1)
        # [2, 1, E] + [1, B, 1] -> [2, B, E] -> [2, B * E]
        return (edge_index.unsqueeze(1) + offsets).reshape(2, -1)
class FusionModule(nn.Module):
    """Merge two feature maps by channel concatenation and a 1x1 convolution.

    Args:
        in_ch: Channel count of the concatenated input (``x1`` plus ``x2``).
        out_ch: Channel count of the fused output.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = [
            nn.Conv2d(in_ch, out_ch, 1),
            nn.InstanceNorm2d(out_ch),
            nn.ELU(),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x1, x2):
        """Concatenate ``x1`` and ``x2`` on dim 1 and project to ``out_ch``."""
        stacked = torch.cat([x1, x2], dim=1)
        return self.conv(stacked)
class GinkaEncoder(nn.Module): class GinkaEncoder(nn.Module):
"""编码器(下采样)部分""" """编码器(下采样)部分"""
def __init__(self, in_ch, out_ch): def __init__(self, in_ch, out_ch):
@ -59,6 +119,21 @@ class GinkaEncoder(nn.Module):
x = self.pool(x) x = self.pool(x)
return x return x
class GinkaGCNFusedEncoder(nn.Module):
    """Down-sampling stage that fuses convolutional and grid-GCN features.

    The input goes through ``ConvBlock`` and a 2x max-pool; the pooled map is
    then processed by a ``GCNBlock`` and both maps are merged by a
    ``FusionModule``.

    Args:
        in_ch: Input channels.
        out_ch: Output channels.
        w, h: Grid size handed to the ``GCNBlock``; it sees the pooled map.
    """

    def __init__(self, in_ch, out_ch, w, h):
        super().__init__()
        doubled = out_ch * 2
        self.conv = ConvBlock(in_ch, out_ch)
        self.gcn = GCNBlock(out_ch, doubled, out_ch, w, h)
        self.pool = nn.MaxPool2d(2)
        self.fusion = FusionModule(doubled, out_ch)

    def forward(self, x):
        pooled = self.pool(self.conv(x))
        graph_feat = self.gcn(pooled)
        return self.fusion(pooled, graph_feat)
class GinkaUpSample(nn.Module): class GinkaUpSample(nn.Module):
def __init__(self, in_ch, out_ch): def __init__(self, in_ch, out_ch):
super().__init__() super().__init__()
@ -84,6 +159,44 @@ class GinkaDecoder(nn.Module):
x = self.conv(x) x = self.conv(x)
return x return x
class GinkaGCNFusedDecoder(nn.Module):
    """Up-sampling stage that fuses convolutional and grid-GCN features.

    ``x`` is up-sampled, concatenated with the skip connection ``feat`` on the
    channel axis, convolved, and finally fused with a ``GCNBlock`` view of the
    same feature map.

    Args:
        in_ch: Channels of the incoming (low-resolution) features.
        out_ch: Output channels.
        w, h: Grid size handed to the ``GCNBlock``.
    """

    def __init__(self, in_ch, out_ch, w, h):
        super().__init__()
        doubled = out_ch * 2
        self.upsample = GinkaUpSample(in_ch, in_ch // 2)
        self.conv = ConvBlock(in_ch, out_ch)
        self.gcn = GCNBlock(out_ch, doubled, out_ch, w, h)
        self.fusion = FusionModule(doubled, out_ch)

    def forward(self, x, feat):
        merged = torch.cat([self.upsample(x), feat], dim=1)
        local_feat = self.conv(merged)
        graph_feat = self.gcn(local_feat)
        return self.fusion(local_feat, graph_feat)
class GinkaBottleneck(nn.Module):
    """U-Net bottleneck fusing a transformer branch with a grid-GCN branch.

    Args:
        module_ch: Channel count of the bottleneck feature map.
        w, h: Spatial size of the bottleneck feature map.

    The sizes used in ``forward`` are derived from the input tensor, and the
    GCN grid from ``w``/``h``, instead of the previously hard-coded
    512 channels / 4x4 grid.  For ``module_ch=512, w=4, h=4`` the behaviour is
    unchanged.
    """

    def __init__(self, module_ch, w, h):
        super().__init__()
        flat_dim = module_ch * w * h
        self.transformer = GinkaTransformerEncoder(
            in_dim=flat_dim, hidden_dim=flat_dim, out_dim=flat_dim,
            # One token per spatial position (16 for the 4x4 bottleneck).
            token_size=w * h, ff_dim=1024, num_layers=4
        )
        self.gcn = GCNBlock(module_ch, module_ch * 2, module_ch, w, h)
        self.fusion = FusionModule(module_ch * 2, module_ch)

    def forward(self, x):
        """Fuse both branches; ``x`` is ``[B, C, H, W]`` and so is the result."""
        B, C, H, W = x.shape

        # Transformer branch: one token per spatial position -> [B, H*W, C]
        tokens = x.reshape(B, C, H * W).permute(0, 2, 1)
        tokens = self.transformer(tokens)
        x1 = tokens.permute(0, 2, 1).reshape(B, C, H, W)

        # Graph branch works directly on the image-shaped tensor.
        x2 = self.gcn(x)

        return self.fusion(x1, x2)
class GinkaUNet(nn.Module): class GinkaUNet(nn.Module):
def __init__(self, base_ch=64, out_ch=32, feat_dim=1024): def __init__(self, base_ch=64, out_ch=32, feat_dim=1024):
"""Ginka Model UNet 部分 """Ginka Model UNet 部分
@ -94,17 +207,14 @@ class GinkaUNet(nn.Module):
token_size=4, ff_dim=feat_dim*2, num_layers=4 token_size=4, ff_dim=feat_dim*2, num_layers=4
) )
self.down1 = ConvBlock(2, base_ch) self.down1 = ConvBlock(2, base_ch)
self.down2 = GinkaEncoder(base_ch, base_ch*2) self.down2 = GinkaGCNFusedEncoder(base_ch, base_ch*2, 16, 16)
self.down3 = GinkaEncoder(base_ch*2, base_ch*4) self.down3 = GinkaGCNFusedEncoder(base_ch*2, base_ch*4, 8, 8)
self.down4 = GinkaEncoder(base_ch*4, base_ch*8) self.down4 = GinkaEncoder(base_ch*4, base_ch*8)
self.bottleneck = GinkaTransformerEncoder( self.bottleneck = GinkaBottleneck(base_ch*8, 4, 4)
in_dim=base_ch*8*4*4, hidden_dim=base_ch*8*4*4, out_dim=base_ch*8*4*4,
token_size=16, ff_dim=1024, num_layers=4
)
self.up1 = GinkaDecoder(base_ch*8, base_ch*4) self.up1 = GinkaGCNFusedDecoder(base_ch*8, base_ch*4, 8, 8)
self.up2 = GinkaDecoder(base_ch*4, base_ch*2) self.up2 = GinkaGCNFusedDecoder(base_ch*4, base_ch*2, 16, 16)
self.up3 = GinkaDecoder(base_ch*2, base_ch) self.up3 = GinkaGCNFusedDecoder(base_ch*2, base_ch, 32, 32)
self.final = nn.Sequential( self.final = nn.Sequential(
nn.Conv2d(base_ch, out_ch, 1), nn.Conv2d(base_ch, out_ch, 1),
@ -121,9 +231,7 @@ class GinkaUNet(nn.Module):
x2 = self.down2(x1) # [B, 128, 16, 16] x2 = self.down2(x1) # [B, 128, 16, 16]
x3 = self.down3(x2) # [B, 256, 8, 8] x3 = self.down3(x2) # [B, 256, 8, 8]
x4 = self.down4(x3) # [B, 512, 4, 4] x4 = self.down4(x3) # [B, 512, 4, 4]
x4 = x4.view(B, 512, 16).permute(0, 2, 1) # [B, 16, 512] x4 = self.bottleneck(x4) # [B, 512, 4, 4]
x4 = self.bottleneck(x4) # [B, 16, 512]
x4 = x4.permute(0, 2, 1).view(B, 512, 4, 4) # [B, 512, 4, 4]
# 上采样 # 上采样
x = self.up1(x4, x3) # [B, 256, 8, 8] x = self.up1(x4, x3) # [B, 256, 8, 8]

View File

@ -11,6 +11,7 @@ from .model.model import GinkaModel
from .dataset import GinkaWGANDataset from .dataset import GinkaWGANDataset
from .model.loss import WGANGinkaLoss from .model.loss import WGANGinkaLoss
from minamo.model.model import MinamoScoreModule from minamo.model.model import MinamoScoreModule
from minamo.model.similarity import MinamoSimilarityModel
from shared.graph import batch_convert_soft_map_to_graph from shared.graph import batch_convert_soft_map_to_graph
from shared.image import matrix_to_image_cv from shared.image import matrix_to_image_cv
from shared.constant import VISION_WEIGHT, TOPO_WEIGHT from shared.constant import VISION_WEIGHT, TOPO_WEIGHT
@ -26,10 +27,12 @@ disable_tqdm = not sys.stdout.isatty()
def parse_arguments():
    """Parse the command-line options of the WGAN training script.

    Boolean options accept ``true/false``, ``yes/no``, ``on/off`` and ``1/0``
    (case-insensitive).  With ``type=bool`` every non-empty string, including
    ``"False"``, was parsed as ``True``, so ``--load_optim`` could never be
    switched off from the command line.

    Returns:
        The ``argparse.Namespace`` with ``resume``, ``state_ginka``,
        ``state_minamo``, ``train``, ``epochs``, ``checkpoint`` and
        ``load_optim``.
    """
    def str_to_bool(value):
        text = str(value).strip().lower()
        if text in ("1", "true", "t", "yes", "y", "on"):
            return True
        # The empty string stays falsy, as it was under bool("").
        if text in ("0", "false", "f", "no", "n", "off", ""):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(description="training codes")
    parser.add_argument("--resume", type=str_to_bool, default=False)
    parser.add_argument("--state_ginka", type=str, default="result/wgan/ginka-100.pth")
    parser.add_argument("--state_minamo", type=str, default="result/wgan/minamo-100.pth")
    parser.add_argument("--train", type=str, default="ginka-dataset.json")
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--checkpoint", type=int, default=5)
    parser.add_argument("--load_optim", type=str_to_bool, default=True)
    args = parser.parse_args()
    return args
@ -49,14 +52,20 @@ def train():
ginka = GinkaModel() ginka = GinkaModel()
minamo = MinamoScoreModule() minamo = MinamoScoreModule()
minamo_sim = MinamoSimilarityModel()
ginka.to(device) ginka.to(device)
minamo.to(device) minamo.to(device)
minamo_sim.to(device)
dataset = GinkaWGANDataset(args.train, device) dataset = GinkaWGANDataset(args.train, device)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True) dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9)) optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9))
optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-5, betas=(0.0, 0.9)) optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-5, betas=(0.0, 0.9))
optimizer_minamo_sim = optim.Adam(minamo_sim.parameters(), lr=1e-4)
# scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=args.epochs)
# scheduler_minamo = optim.lr_scheduler.CosineAnnealingLR(optimizer_minamo, T_max=args.epochs)
criterion = WGANGinkaLoss() criterion = WGANGinkaLoss()
@ -67,14 +76,29 @@ def train():
tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED) tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED)
if args.resume: if args.resume:
data = torch.load(args.state_ginka, map_location=device) data_ginka = torch.load(args.state_ginka, map_location=device)
ginka.load_state_dict(data["model_state"], strict=False) data_minamo = torch.load(args.state_minamo, map_location=device)
data = torch.load(args.state_minamo, map_location=device)
minamo.load_state_dict(data["model_state"], strict=False) ginka.load_state_dict(data_ginka["model_state"], strict=False)
minamo.load_state_dict(data_minamo["model_state"], strict=False)
if data_ginka["c_steps"] is not None and data_ginka["g_steps"] is not None:
c_steps = data_ginka["c_steps"]
g_steps = data_ginka["g_steps"]
if args.load_optim:
if data_ginka["optim_state"] is not None:
optimizer_ginka.load_state_dict(data_ginka["optim_state"])
if data_minamo["optim_state"] is not None:
optimizer_minamo.load_state_dict(data_minamo["optim_state"])
if data_minamo["optim_state_sim"] is not None:
optimizer_minamo_sim.load_state_dict(data_minamo["optim_state_sim"])
print("Train from loaded state.") print("Train from loaded state.")
for epoch in tqdm(range(args.epochs), desc="GAN Training", disable=disable_tqdm): for epoch in tqdm(range(args.epochs), desc="WGAN Training", disable=disable_tqdm):
loss_total_minamo = torch.Tensor([0]).to(device) loss_total_minamo = torch.Tensor([0]).to(device)
loss_total_minamo_sim = torch.Tensor([0]).to(device)
loss_total_ginka = torch.Tensor([0]).to(device) loss_total_ginka = torch.Tensor([0]).to(device)
dis_total = torch.Tensor([0]).to(device) dis_total = torch.Tensor([0]).to(device)
@ -83,13 +107,11 @@ def train():
real_data = real_data.to(device) real_data = real_data.to(device)
real_graph = batch_convert_soft_map_to_graph(real_data) real_graph = batch_convert_soft_map_to_graph(real_data)
optimizer_ginka.zero_grad()
# ---------- 训练判别器 # ---------- 训练判别器
for _ in range(c_steps): for _ in range(c_steps):
# 生成假样本 # 生成假样本
optimizer_minamo.zero_grad() optimizer_minamo.zero_grad()
z = torch.randn(batch_size, 1024, device=device) z = torch.rand(batch_size, 1024, device=device)
fake_data = ginka(z) fake_data = ginka(z)
fake_data = fake_data.detach() fake_data = fake_data.detach()
@ -97,59 +119,65 @@ def train():
# 反向传播 # 反向传播
dis, loss_d = criterion.discriminator_loss(minamo, real_data, real_graph, fake_data) dis, loss_d = criterion.discriminator_loss(minamo, real_data, real_graph, fake_data)
loss_d.backward() loss_d.backward()
# torch.nn.utils.clip_grad_norm_(minamo.parameters(), max_norm=1.0) # torch.nn.utils.clip_grad_norm_(minamo.parameters(), max_norm=2.0)
# total_norm = torch.linalg.vector_norm(torch.stack([torch.linalg.vector_norm(p.grad) for p in minamo.topo_model.parameters()]), 2) # total_norm = torch.linalg.vector_norm(torch.stack([torch.linalg.vector_norm(p.grad) for p in minamo.topo_model.parameters()]), 2)
# print("Critic 梯度范数:", total_norm.item()) # print("Critic 梯度范数:", total_norm.item())
# print("Critic 输入范围:", fake_data.min().item(), fake_data.max().item(), real_data.min().item(), real_data.max().item()) # print("Critic 输入范围:", fake_data.min().item(), fake_data.max().item(), real_data.min().item(), real_data.max().item())
# print("Critic 输出范围:", d_real.min().item(), d_real.max().item()) # print("Critic 输出范围:", d_real.min().item(), d_real.max().item())
optimizer_minamo.step() optimizer_minamo.step()
loss_total_minamo += loss_d loss_total_minamo += loss_d.detach()
dis_total += dis dis_total += dis.detach()
# ---------- 训练生成器 # ---------- 训练生成器
for _ in range(g_steps): for _ in range(g_steps):
optimizer_ginka.zero_grad()
# optimizer_minamo_sim.zero_grad()
z1 = torch.randn(batch_size, 1024, device=device) z1 = torch.randn(batch_size, 1024, device=device)
z2 = torch.randn(batch_size, 1024, device=device) z2 = torch.randn(batch_size, 1024, device=device)
fake_softmax1, fake_softmax2 = ginka(z1), ginka(z2) fake_softmax1, fake_softmax2 = ginka(z1), ginka(z2)
loss_g = criterion.generator_loss(minamo, fake_softmax1, fake_softmax2) # 先训练辅助判别器
# loss_c_assist = criterion.discriminator_loss_assist2(minamo_sim, real_data, fake_softmax1, fake_softmax2)
# loss_c_assist.backward(retain_graph=True)
# optimizer_minamo_sim.step()
loss_g = criterion.generator_loss(minamo, minamo_sim, fake_softmax1, fake_softmax2)
loss_g.backward() loss_g.backward()
optimizer_ginka.step() optimizer_ginka.step()
loss_total_ginka += loss_g loss_total_ginka += loss_g
# loss_total_minamo_sim += loss_c_assist.detach()
# tqdm.write(f"{dis.item():.12f}, {loss_d.item():.12f}, {loss_g.item():.12f}") # tqdm.write(f"{dis.item():.12f}, {loss_d.item():.12f}, {loss_g.item():.12f}")
avg_loss_ginka = loss_total_ginka.item() / len(dataloader) / g_steps avg_loss_ginka = loss_total_ginka.item() / len(dataloader) / g_steps
avg_loss_minamo = loss_total_minamo.item() / len(dataloader) / c_steps avg_loss_minamo = loss_total_minamo.item() / len(dataloader) / c_steps
avg_loss_minamo_sim = loss_total_minamo_sim.item() / len(dataloader) / g_steps
avg_dis = dis_total.item() / len(dataloader) / c_steps avg_dis = dis_total.item() / len(dataloader) / c_steps
tqdm.write( tqdm.write(
f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | Wasserstein Loss: {avg_dis:.8f} | Loss Ginka: {avg_loss_ginka:.8f} | Loss Minamo: {avg_loss_minamo:.8f}" f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +\
f"Epoch: {epoch + 1} | W Loss: {avg_dis:.8f} | " +\
f"G Loss: {avg_loss_ginka:.8f} | D Loss: {avg_loss_minamo:.8f} | " +\
f"lr G: {(optimizer_ginka.param_groups[0]['lr']):.8f}"
) )
# scheduler_ginka.step()
# scheduler_minamo.step()
if avg_dis < 0: if avg_dis < 0:
g_steps = max(int(-avg_dis * 5), 1) g_steps = max(int(-avg_dis * 5), 1)
else: else:
g_steps = 1 g_steps = 1
# if avg_dis > 0: if avg_loss_ginka > 0 or avg_loss_minamo > 0:
# c_steps = min(max(int(avg_dis * 5), 1), 5) c_steps = min(5 + (avg_loss_ginka + avg_loss_minamo) * 5, 15)
# else: else:
# c_steps = 1 c_steps = 5
# if avg_loss_minamo > 0:
# c_steps += min(max(int(avg_loss_minamo * 3), 1), 5)
# else:
# c_steps += 0
# if avg_dis > 3:
# c_steps = 3
# else:
# c_steps = 1
# 每若干轮输出一次图片,并保存检查点 # 每若干轮输出一次图片,并保存检查点
if (epoch + 1) % 5 == 0: if (epoch + 1) % args.checkpoint == 0:
# 输出 20 张图片,每批次 4 张,一共五批 # 输出 20 张图片,每批次 4 张,一共五批
idx = 0 idx = 0
with torch.no_grad(): with torch.no_grad():
@ -165,18 +193,25 @@ def train():
# 保存检查点 # 保存检查点
torch.save({ torch.save({
"model_state": ginka.state_dict() "model_state": ginka.state_dict(),
"optim_state": optimizer_ginka.state_dict(),
"c_steps": c_steps,
"g_steps": g_steps
}, f"result/wgan/ginka-{epoch + 1}.pth") }, f"result/wgan/ginka-{epoch + 1}.pth")
torch.save({ torch.save({
"model_state": minamo.state_dict() "model_state": minamo.state_dict(),
"model_state_sim": minamo_sim.state_dict(),
"optim_state": optimizer_minamo.state_dict(),
"optim_state_sim": optimizer_minamo_sim.state_dict()
}, f"result/wgan/minamo-{epoch + 1}.pth") }, f"result/wgan/minamo-{epoch + 1}.pth")
print("Train ended.") print("Train ended.")
torch.save({ torch.save({
"model_state": ginka.state_dict() "model_state": ginka.state_dict(),
}, f"result/ginka.pth") }, f"result/ginka.pth")
torch.save({ torch.save({
"model_state": minamo.state_dict() "model_state": minamo.state_dict(),
"model_state_sim": minamo_sim.state_dict(),
}, f"result/minamo.pth") }, f"result/minamo.pth")
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -20,40 +20,6 @@ class MinamoModel(nn.Module):
return vision_feat, topo_feat return vision_feat, topo_feat
class MinamoVisionScore(nn.Module):
def __init__(self, tile_types=32):
super().__init__()
# 视觉相似度部分
self.vision_model = MinamoVisionModel(tile_types)
# 输出层
self.vision_fc = nn.Sequential(
nn.Linear(512, 2048),
nn.LeakyReLU(0.2),
nn.Linear(2048, 1)
)
def forward(self, map):
vision_feat = self.vision_model(map)
vis_score = self.vision_fc(vision_feat)
return vis_score
class MinamoTopoScore(nn.Module):
def __init__(self, tile_types=32):
super().__init__()
# 拓扑相似度部分
self.topo_model = MinamoTopoModel(tile_types)
# 输出层
self.topo_fc = nn.Sequential(
nn.Linear(512, 2048),
nn.LeakyReLU(0.2),
nn.Linear(2048, 1)
)
def forward(self, graph):
topo_feat = self.topo_model(graph)
topo_score = self.topo_fc(topo_feat)
return topo_score
class MinamoScoreModule(nn.Module): class MinamoScoreModule(nn.Module):
def __init__(self, tile_types=32): def __init__(self, tile_types=32):
super().__init__() super().__init__()

View File

@ -0,0 +1,83 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.data import Data
class MinamoSimilarityVision(nn.Module):
    """Convolutional encoder mapping a map tensor to a feature vector.

    Three 3x3 conv + InstanceNorm + ReLU stages widen the channels to
    ``in_ch * 8`` (the last convolution is unpadded), followed by global
    average pooling and a linear projection to ``out_ch``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        widths = [in_ch, in_ch * 2, in_ch * 4, in_ch * 8]
        paddings = [1, 1, 0]

        stages = []
        for src, dst, pad in zip(widths[:-1], widths[1:], paddings):
            stages.append(nn.Conv2d(src, dst, 3, padding=pad))
            stages.append(nn.InstanceNorm2d(dst))
            stages.append(nn.ReLU())
        stages.append(nn.AdaptiveAvgPool2d(1))

        self.conv = nn.Sequential(*stages)
        self.fc = nn.Sequential(
            nn.Linear(widths[-1], out_ch),
        )

    def forward(self, x):
        pooled = self.conv(x)
        flat = pooled.view(pooled.size(0), -1)  # [B, in_ch * 8]
        return self.fc(flat)
class MinamoSimilarityTopo(nn.Module):
    """Graph encoder mapping a (batched) map graph to one vector per graph.

    Node features are embedded by a linear layer, passed through three
    ``GCNConv`` layers (each followed by LayerNorm and ReLU) that widen the
    features to ``hidden_dim * 8``, mean-pooled per graph and projected to
    ``out_ch``.
    """

    def __init__(self, in_ch, hidden_dim, out_ch):
        super().__init__()
        d1, d2, d4, d8 = (hidden_dim * k for k in (1, 2, 4, 8))

        self.input_fc = nn.Sequential(
            nn.Linear(in_ch, d1),
            nn.LayerNorm(d1),
            nn.ReLU(),
        )

        self.conv1 = GCNConv(d1, d2)
        self.conv2 = GCNConv(d2, d4)
        self.conv3 = GCNConv(d4, d8)

        self.norm1 = nn.LayerNorm(d2)
        self.norm2 = nn.LayerNorm(d4)
        self.norm3 = nn.LayerNorm(d8)

        self.output_fc = nn.Sequential(
            nn.Linear(d8, out_ch)
        )

    def forward(self, graph: Data):
        layers = (
            (self.conv1, self.norm1),
            (self.conv2, self.norm2),
            (self.conv3, self.norm3),
        )
        h = self.input_fc(graph.x)
        for conv, norm in layers:
            h = F.relu(norm(conv(h, graph.edge_index)))
        # One row per graph in the batch.
        pooled = global_mean_pool(h, graph.batch)
        return self.output_fc(pooled)
class MinamoSimilarityModel(nn.Module):
    """Two-branch feature extractor used as the auxiliary similarity critic.

    ``forward`` returns ``(vision_features, topo_features)``, both 512-wide.
    """

    def __init__(self, tile_type=32):
        super().__init__()
        feature_dim = 512
        self.vision = MinamoSimilarityVision(tile_type, feature_dim)
        self.topo = MinamoSimilarityTopo(tile_type, 64, feature_dim)

    def forward(self, x, graph):
        # The vision branch is evaluated before the graph branch.
        return self.vision(x), self.topo(graph)

View File

@ -7,16 +7,16 @@ class MinamoVisionModel(nn.Module):
def __init__(self, in_ch=32, out_dim=512): def __init__(self, in_ch=32, out_dim=512):
super().__init__() super().__init__()
self.conv = nn.Sequential( self.conv = nn.Sequential(
spectral_norm(nn.Conv2d(in_ch, in_ch*2, 3, stride=2)), # 6*6 spectral_norm(nn.Conv2d(in_ch, in_ch*2, 3)), # 11*11
nn.LeakyReLU(0.2), nn.LeakyReLU(0.2),
spectral_norm(nn.Conv2d(in_ch*2, in_ch*4, 3)), #4*4 spectral_norm(nn.Conv2d(in_ch*2, in_ch*4, 3)), #9*9
nn.LeakyReLU(0.2), nn.LeakyReLU(0.2),
spectral_norm(nn.Conv2d(in_ch*4, in_ch*8, 3)), # 2*2 spectral_norm(nn.Conv2d(in_ch*4, in_ch*8, 3)), # 7*7
nn.LeakyReLU(0.2), nn.LeakyReLU(0.2),
nn.Flatten() nn.AdaptiveAvgPool2d(2)
) )
self.fc = nn.Sequential( self.fc = nn.Sequential(
spectral_norm(nn.Linear(in_ch*8*2*2, out_dim)), spectral_norm(nn.Linear(in_ch*8*2*2, out_dim)),
@ -25,5 +25,6 @@ class MinamoVisionModel(nn.Module):
def forward(self, x): def forward(self, x):
x = self.conv(x) x = self.conv(x)
x = x.view(x.size(0), -1)
x = self.fc(x) x = self.fc(x)
return x return x

View File

@ -9,7 +9,7 @@ class ChannelAttention(nn.Module):
self.channel_att = nn.Sequential( self.channel_att = nn.Sequential(
nn.AdaptiveAvgPool2d(1), nn.AdaptiveAvgPool2d(1),
nn.Conv2d(channels, channels//reduction, 1), nn.Conv2d(channels, channels//reduction, 1),
nn.GELU(), nn.ELU(),
nn.Conv2d(channels//reduction, channels, 1), nn.Conv2d(channels//reduction, channels, 1),
nn.Sigmoid() nn.Sigmoid()
) )

View File

@ -2,5 +2,5 @@ VIS_DIM = 512
TOPO_DIM = 512 TOPO_DIM = 512
FEAT_DIM = 1024 FEAT_DIM = 1024
VISION_WEIGHT = 0.3 VISION_WEIGHT = 0.2
TOPO_WEIGHT = 0.7 TOPO_WEIGHT = 0.8

273
shared/similarity/topo.py Normal file
View File

@ -0,0 +1,273 @@
# Converted Python version of the JS code
import math
from typing import Dict, Set, List, Tuple, Union
from collections import deque, defaultdict
# 拓扑相似度,由 ChatGPT-4o 从 ts 转译而来
class ResourceArea:
    """A connected region of resource tiles collected by the flood fill."""

    def __init__(self):
        self.type = 'resource'
        # Tile id -> number of cells of that tile inside the area.
        self.resources: Dict[int, int] = dict()
        # Flat cell indices (y * width + x) that belong to the area.
        self.members: Set[int] = set()
        # Flat indices of the branch cells touching the area.
        self.neighbor: Set[int] = set()
class BranchNode:
    """Graph node for a single branch tile."""

    def __init__(self, tile: int):
        self.type = 'branch'
        # Tile id found at this cell.
        self.tile = tile
        # Flat indices of adjacent branch cells and of the members of
        # adjacent resource areas.
        self.neighbor: Set[int] = set()
class ResourceNode:
    """Graph node for a single resource tile that belongs to a ResourceArea."""

    def __init__(self, resource_type: int, area: ResourceArea):
        self.type = 'resource'
        # Tile id of the resource at this cell.
        self.resourceType = resource_type
        # Shares the area's neighbour set (same object, not a copy), so all
        # nodes of one area see the same neighbours.
        self.neighbor = area.neighbor
        self.resourceArea = area
# Type of the values stored in GinkaGraph.graph: a branch tile node or a
# resource tile node.
GinkaNode = Union[BranchNode, ResourceNode]
class GinkaGraph:
    """Topological graph of the part of a map reachable from one entrance."""

    def __init__(self):
        # Flat cell index -> node.
        self.graph: Dict[int, GinkaNode] = dict()
        # Flat cell index -> index into areaMap.
        self.resourceMap: Dict[int, int] = dict()
        self.areaMap: List[ResourceArea] = list()
        # Entrance cells reached while exploring this graph.
        self.visitedEntrance: Set[int] = set()
        # Every cell reached while exploring this graph.
        self.visited: Set[int] = set()
class GinkaTopologicalGraphs:
    """All per-entrance graphs of one map plus the cells none of them reached."""

    def __init__(self):
        self.graphs: List[GinkaGraph] = list()
        # Entrance cell index -> the graph that contains it.
        self.entranceMap: Dict[int, GinkaGraph] = dict()
        # Cells whose tile is not 1 that no graph visited.
        self.unreachable: Set[int] = set()
# Tile ids 0..12 (not referenced in the visible part of this module).
TILE_TYPE = set(range(13))
# Tiles that build_graph_from_entrance records as branch nodes.
BRANCH_TYPE = {6, 7, 8, 9}
# Tiles used as starting points for graph construction.
ENTRANCE_TYPE = {10, 11}
# Tiles that find_resource_nodes flood-fills into resource areas.  Entrances
# (10, 11) are included; tile 0 is included too but is skipped when
# ResourceNode objects are created.  Tile 1 is in none of the sets above and is
# the tile the traversal refuses to step onto (presumably a wall -- confirm).
RESOURCE_TYPE = {0, 2, 3, 4, 5, 10, 11, 12}
# 4-neighbourhood offsets as (dx, dy).
directions: List[Tuple[int, int]] = [
(-1, 0), (1, 0), (0, -1), (0, 1)
]
def find_resource_nodes(map_: List[List[int]]):
    """Group 4-connected resource tiles of a map into areas.

    Args:
        map_: Rectangular grid of tile ids, indexed as ``map_[y][x]``.

    Returns:
        ``(areas, resource_map)`` where ``areas`` is the list of
        ``ResourceArea`` objects in scan order (row by row) and
        ``resource_map`` maps every resource cell's flat index
        (``y * width + x``) to the index of its area in ``areas``.
        An empty map yields ``([], {})``.

    Each cell contributes exactly one to ``area.resources``.  The seed cell
    used to be counted twice: once when the area was created and again when
    it was popped from the queue.
    """
    if not map_ or not map_[0]:
        return [], {}

    width, height = len(map_[0]), len(map_)
    visited = set()
    areas = []
    resource_map = {}

    for ny in range(height):
        for nx in range(width):
            index = ny * width + nx
            if index in visited or map_[ny][nx] not in RESOURCE_TYPE:
                continue

            # Breadth-first flood fill starting from this unvisited cell.
            area = ResourceArea()
            area_id = len(areas)
            queue = deque([(nx, ny)])

            while queue:
                cx, cy = queue.popleft()
                cindex = cy * width + cx
                if cindex in visited:
                    continue
                ctile = map_[cy][cx]
                if ctile not in RESOURCE_TYPE:
                    continue

                visited.add(cindex)
                area.resources[ctile] = area.resources.get(ctile, 0) + 1
                area.members.add(cindex)
                resource_map[cindex] = area_id

                for dx, dy in directions:
                    px, py = cx + dx, cy + dy
                    if 0 <= px < width and 0 <= py < height:
                        queue.append((px, py))

            areas.append(area)

    return areas, resource_map
def build_graph_from_entrance(map_: List[List[int]], entrance: int, resource_map: Dict[int, int], area_map: List[ResourceArea]) -> GinkaGraph:
    """Build the topological graph reachable from one entrance cell.

    Args:
        map_: Grid of tile ids, indexed as ``map_[y][x]``.
        entrance: Flat index (``y * width + x``) of the starting cell.
        resource_map: Flat cell index -> index into ``area_map``.
        area_map: Resource areas; their ``neighbor`` sets are updated in place.

    Returns:
        A ``GinkaGraph`` whose ``visited`` / ``visitedEntrance`` sets describe
        the explored cells and whose ``graph`` holds a ``BranchNode`` for each
        reached branch tile and a ``ResourceNode`` for each non-zero member
        tile of every area in ``area_map``.
    """
    height = len(map_)
    width = len(map_[0])

    def in_bounds(col, row):
        return 0 <= col < width and 0 <= row < height

    result = GinkaGraph()
    result.resourceMap = resource_map
    result.areaMap = area_map

    seen = result.visited
    seen_entrances = result.visitedEntrance
    seen_entrances.add(entrance)

    # Phase 1: breadth-first walk over every cell whose tile is not 1.
    branch_cells = set()
    frontier = deque([(entrance % width, entrance // width)])
    while frontier:
        col, row = frontier.popleft()
        cell = row * width + col
        if cell in seen:
            continue

        tile = map_[row][col]
        if tile in ENTRANCE_TYPE:
            seen_entrances.add(cell)
        if tile in BRANCH_TYPE:
            branch_cells.add(cell)
        seen.add(cell)

        for dx, dy in directions:
            ncol, nrow = col + dx, row + dy
            if in_bounds(ncol, nrow) and map_[nrow][ncol] != 1:
                frontier.append((ncol, nrow))

    # Phase 2: link each branch cell to adjacent branches and resource areas.
    for cell in branch_cells:
        row, col = divmod(cell, width)
        if cell not in result.graph:
            result.graph[cell] = BranchNode(map_[row][col])
        branch = result.graph[cell]

        for dx, dy in directions:
            ncol, nrow = col + dx, row + dy
            if not in_bounds(ncol, nrow):
                continue
            adjacent = nrow * width + ncol
            if adjacent in branch_cells:
                branch.neighbor.add(adjacent)
            elif adjacent in resource_map:
                area = area_map[resource_map[adjacent]]
                area.neighbor.add(cell)
                # The branch is connected to every cell of the touched area.
                branch.neighbor.update(area.members)

    # Phase 3: one resource node per non-zero member tile of every area.
    for area in area_map:
        for cell in area.members:
            row, col = divmod(cell, width)
            tile = map_[row][col]
            if tile != 0:
                result.graph[cell] = ResourceNode(tile, area)

    return result
def build_topological_graph(map_: List[List[int]]) -> GinkaTopologicalGraphs:
    """Split a map into one graph per group of mutually reachable entrances.

    Args:
        map_: Tile grid indexed as ``map_[y][x]``; the width is taken from
            the first row.

    Returns:
        A ``GinkaTopologicalGraphs`` holding the per-entrance graphs, an
        ``entranceMap`` from every reached entrance index to its graph, and
        the ``unreachable`` set of flat indices that no flood fill visited
        and whose tile is not 1.
    """
    height = len(map_)
    width = len(map_[0])
    entrances = {
        y * width + x
        for y in range(height)
        for x in range(width)
        if map_[y][x] in ENTRANCE_TYPE
    }
    area_map, resource_map = find_resource_nodes(map_)

    top_graph = GinkaTopologicalGraphs()
    claimed_entrances = set()
    reached = set()
    for entrance in entrances:
        # An entrance already swallowed by an earlier flood fill belongs to
        # that graph; do not build a duplicate for it.
        if entrance in claimed_entrances:
            continue
        graph = build_graph_from_entrance(map_, entrance, resource_map, area_map)
        top_graph.graphs.append(graph)
        for ent in graph.visitedEntrance:
            claimed_entrances.add(ent)
            top_graph.entranceMap[ent] = graph
        reached.update(graph.visited)

    # Row-major sweep over flat indices for everything no entrance reached.
    for index in range(width * height):
        if index not in reached and map_[index // width][index % width] != 1:
            top_graph.unreachable.add(index)

    return top_graph
class WLNode:
    """Per-node working state for Weisfeiler-Lehman relabelling.

    Attributes:
        originalPos: Position key the node was created for.
        originalLabel: Label the node started with.
        currentLabel: Working label; starts equal to ``originalLabel``.
        neighbors: Adjacent ``WLNode`` objects, filled in by the caller.
    """

    def __init__(self, pos: int, label: str):
        self.originalPos = pos
        # Both label slots begin with the same value.
        self.originalLabel = self.currentLabel = label
        # A fresh list per instance so nodes never share adjacency.
        self.neighbors: List['WLNode'] = []
def encode_node_labels(graph: GinkaGraph) -> List[WLNode]:
    """Wrap every node of ``graph`` in a labelled, linked ``WLNode``.

    Branch nodes (``type == 'branch'``) are labelled ``B:<tile>``; every
    other node is labelled ``R:<resourceType>``.

    Args:
        graph: Graph whose ``graph`` dict maps a position to a node exposing
            ``type``, ``neighbor`` and either ``tile`` or ``resourceType``.

    Returns:
        The ``WLNode`` objects in the iteration order of ``graph.graph``,
        with ``neighbors`` populated.
    """
    wl_by_pos = {}
    for pos, g_node in graph.graph.items():
        if g_node.type == 'branch':
            label = f"B:{g_node.tile}"
        else:
            label = f"R:{g_node.resourceType}"
        wl_by_pos[pos] = WLNode(pos, label)

    wl_nodes = list(wl_by_pos.values())
    for wl_node in wl_nodes:
        for neighbor_pos in graph.graph[wl_node.originalPos].neighbor:
            # Neighbour positions that have no node in this graph are dropped.
            linked = wl_by_pos.get(neighbor_pos)
            if linked is not None:
                wl_node.neighbors.append(linked)
    return wl_nodes
def weisfeiler_lehman_iteration(nodes: List[WLNode], iterations: int, decay: float = 0.6) -> Dict[str, float]:
    """Run Weisfeiler-Lehman relabelling and return weighted label counts.

    Each round replaces a node's label with
    ``"<own label>|<sorted neighbour labels joined by ','>"`` truncated to
    8192 characters. Labels from the first round are counted with weight
    1.0 and every later round is multiplied by ``decay``. The original
    labels are counted last, with the weight left after the final round.

    Note:
        ``currentLabel`` of every node is overwritten in place.

    Args:
        nodes: Nodes to relabel.
        iterations: Number of relabelling rounds.
        decay: Per-round multiplier applied to the counting weight.

    Returns:
        Plain dict mapping each label seen to its accumulated weight.
    """
    label_counts = defaultdict(float)
    weight = 1.0
    for _ in range(iterations):
        # Derive the whole round from the previous labels before writing any
        # of them back, so update order cannot leak into the result.
        relabelled = []
        for node in nodes:
            around = ','.join(sorted(peer.currentLabel for peer in node.neighbors))
            relabelled.append(f"{node.currentLabel}|{around}"[:8192])
        for node, fresh_label in zip(nodes, relabelled):
            node.currentLabel = fresh_label
            label_counts[fresh_label] += weight
        weight *= decay

    for node in nodes:
        label_counts[node.originalLabel] += weight
    return dict(label_counts)
def vectorize_features(features: Dict[str, float], vocab: List[str]) -> List[float]:
    """Project a sparse ``label -> weight`` mapping onto a vocabulary order.

    Args:
        features: Weight per label.
        vocab: Labels defining the output positions. If a label is listed
            more than once, only its first position receives the weight.

    Returns:
        A list as long as ``vocab``; entry ``i`` is the weight of
        ``vocab[i]`` in ``features`` or 0.0 when absent. Labels in
        ``features`` that are not in ``vocab`` are ignored.
    """
    # Build the label -> slot lookup once. Scanning the list with ``in`` and
    # ``index`` for every feature made this quadratic in the vocabulary size.
    slot_of: Dict[str, int] = {}
    for slot, label in enumerate(vocab):
        # setdefault keeps the first occurrence, matching ``list.index``.
        slot_of.setdefault(label, slot)

    vec = [0.0] * len(vocab)
    for label, count in features.items():
        slot = slot_of.get(label)
        if slot is not None:
            vec[slot] += count
    return vec
def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Return the cosine of the angle between vectors ``a`` and ``b``.

    The dot product pairs elements with ``zip`` and so stops at the shorter
    vector, while each norm is taken over its whole vector.

    Returns:
        ``dot / (|a| * |b|)``, or 0.0 when either norm is zero.
    """
    inner = sum(left * right for left, right in zip(a, b))
    lengths = [math.sqrt(sum(v * v for v in vec)) for vec in (a, b)]
    # A zero-length vector has no direction; report "no similarity".
    if 0 in lengths:
        return 0.0
    return inner / (lengths[0] * lengths[1])
def wl_kernel(graph_a: GinkaGraph, graph_b: GinkaGraph, iterations: int = 3) -> float:
    """Compare two graphs with a Weisfeiler-Lehman label-count kernel.

    Both graphs are turned into weighted label counts, laid out over the
    union of the labels either graph produced, and compared by cosine
    similarity.

    Args:
        graph_a: First graph.
        graph_b: Second graph.
        iterations: Number of relabelling rounds applied to each graph.

    Returns:
        Cosine similarity of the two label-count vectors.
    """
    nodes_a = encode_node_labels(graph_a)
    nodes_b = encode_node_labels(graph_b)
    features_a = weisfeiler_lehman_iteration(nodes_a, iterations)
    features_b = weisfeiler_lehman_iteration(nodes_b, iterations)

    # Shared vocabulary: every label seen in either graph.
    vocab = list(set(features_a.keys()) | set(features_b.keys()))
    return cosine_similarity(
        vectorize_features(features_a, vocab),
        vectorize_features(features_b, vocab),
    )
def overall_similarity(a: GinkaTopologicalGraphs, b: GinkaTopologicalGraphs) -> float:
    """Score how alike the graph collections of two maps are.

    Every graph of ``a`` is greedily paired with the not-yet-paired graph of
    ``b`` that has the highest WL-kernel similarity. The mean of those best
    scores is square-rooted and scaled down by the difference in the number
    of unreachable cells.

    Args:
        a: Graph collection of the first map.
        b: Graph collection of the second map.

    Returns:
        ``sqrt(mean best similarity) / (1 + |unreachable_a - unreachable_b|)``,
        or 0.0 when ``a`` has no graphs.
    """
    graphs_a = a.graphs
    graphs_b = b.graphs
    if not graphs_a:
        return 0.0

    total_similarity = 0.0
    compared_graphs: Set[GinkaGraph] = set()
    for ga in graphs_a:
        max_similarity = 0.0
        max_graph = None
        for gb in graphs_b:
            if gb in compared_graphs:
                continue
            min_nodes = min(len(ga.graph), len(gb.graph))
            # Round count grows with log(size). ``math.log(0)`` raises
            # ValueError, so a graph without nodes gets the minimum of one
            # round instead of crashing the whole comparison.
            if min_nodes > 0:
                iterations = max(1, math.ceil(math.log(min_nodes)))
            else:
                iterations = 1
            similarity = wl_kernel(ga, gb, iterations)
            if similarity > max_similarity and not math.isnan(similarity):
                max_similarity = similarity
                max_graph = gb
            # A perfect match cannot be beaten; stop scanning candidates.
            if similarity == 1:
                break
        total_similarity += max_similarity
        # Compare against None explicitly so the pairing does not depend on
        # the truthiness of the graph object.
        if max_graph is not None:
            compared_graphs.add(max_graph)

    reduction = 1 / (1 + abs(len(a.unreachable) - len(b.unreachable)))
    return math.sqrt(total_similarity / len(graphs_a)) * reduction

View File

@ -0,0 +1,75 @@
from typing import List, Dict
import math
import numpy as np
# Visual similarity; translated from TypeScript by ChatGPT-4o
class VisualSimilarityConfig:
    """Settings read by ``calculate_visual_similarity``.

    Attributes:
        type_weights: Importance weight per tile type id.
        enable_visual_focus: Whether centre-weighted focus weights are used.
        enable_density_awareness: Whether local-density weights are used.
    """

    def __init__(self):
        self.type_weights: Dict[int, float] = {
            0: 0.2,
            1: 0.3,
            2: 0.6,
            3: 0.7,
            4: 0.7,
            5: 0.5,
            6: 0.4,
            7: 0.5,
            8: 0.6,
            9: 0.6,
            10: 0.4,
            11: 0.4,
            12: 0.7,
        }
        self.enable_visual_focus: bool = True
        self.enable_density_awareness: bool = True
def generate_focus_weights(rows: int, cols: int) -> List[List[float]]:
    """Return a ``rows`` x ``cols`` grid of centre-biased weights.

    Each weight is ``1.0 + 0.6 * g`` where ``g`` is a Gaussian (sigma 0.3)
    of the cell's distance from ``(cols / 2, rows / 2)``, with each axis
    normalised by the grid size along that axis.

    NOTE(review): the centre is ``cols / 2``, not ``(cols - 1) / 2``, so the
    peak is shifted half a cell towards higher indices -- confirm this
    matches the TypeScript original before changing it.
    """
    center_x = cols / 2
    center_y = rows / 2

    def weight_at(i: int, j: int) -> float:
        dx = (j - center_x) / cols
        dy = (i - center_y) / rows
        distance = math.sqrt(dx ** 2 + dy ** 2)
        gaussian = math.exp(-(distance ** 2) / (2 * 0.3 ** 2))
        return 1.0 + 0.6 * gaussian

    return [[weight_at(i, j) for j in range(cols)] for i in range(rows)]
def calculate_density_impact(map1: List[List[int]], map2: List[List[int]], type_weights: Dict[int, float]) -> List[List[float]]:
    """Return a per-cell weight that grows with local tile importance.

    For every cell, the mean of both maps' type weights is summed over the
    in-bounds cells of the surrounding 3x3 window. The sum is always divided
    by 9, so border cells with fewer in-bounds neighbours score lower.

    Args:
        map1: First tile grid; it defines the output size.
        map2: Second tile grid, indexed at the same positions as ``map1``.
        type_weights: Weight per tile type; unknown types count as 0.5.

    Returns:
        Grid of ``1.0 + 0.4 * (window sum / 9)`` values.
    """
    rows, cols = len(map1), len(map1[0])
    window_size = 3
    reach = window_size // 2
    window_cells = window_size ** 2
    offsets = range(-reach, reach + 1)

    density_map = []
    for i in range(rows):
        out_row = []
        for j in range(cols):
            # Accumulate with ``+=`` in row-major window order; ``sum()``
            # may round floats differently on newer Python versions.
            density = 0
            for di in offsets:
                ni = i + di
                if not 0 <= ni < rows:
                    continue
                for dj in offsets:
                    nj = j + dj
                    if 0 <= nj < cols:
                        weight1 = type_weights.get(map1[ni][nj], 0.5)
                        weight2 = type_weights.get(map2[ni][nj], 0.5)
                        density += (weight1 + weight2) / 2
            out_row.append(1.0 + 0.4 * (density / window_cells))
        density_map.append(out_row)
    return density_map
def calculate_visual_similarity(map1: List[List[int]], map2: List[List[int]], config: VisualSimilarityConfig = None) -> float:
    """Return the weighted fraction of cells on which two maps agree.

    Every cell contributes ``base * focus * density``, where ``base`` is the
    larger of the two tiles' type weights (0.5 for unknown types). The score
    is the matching cells' contribution divided by the total contribution.

    Args:
        map1: First tile grid.
        map2: Second tile grid.
        config: Weights and feature switches; a default
            ``VisualSimilarityConfig`` is created when omitted.

    Returns:
        Ratio of matched to total contribution. 0.0 is returned when the row
        counts or first-row lengths differ, or when the total contribution
        is not positive.
    """
    config = VisualSimilarityConfig() if config is None else config

    # Row count is compared first so the first rows are only inspected when
    # both maps have the same number of rows.
    same_shape = len(map1) == len(map2) and len(map1[0]) == len(map2[0])
    if not same_shape:
        return 0.0

    rows, cols = len(map1), len(map1[0])

    if config.enable_visual_focus:
        focus_weights = generate_focus_weights(rows, cols)
    else:
        focus_weights = [[1.0 for _ in range(cols)] for _ in range(rows)]

    if config.enable_density_awareness:
        density_map = calculate_density_impact(map1, map2, config.type_weights)
    else:
        density_map = [[1.0 for _ in range(cols)] for _ in range(rows)]

    type_weights = config.type_weights
    total_score = 0.0
    max_possible_score = 0.0
    for i in range(rows):
        for j in range(cols):
            type1, type2 = map1[i][j], map2[i][j]
            base_weight = max(type_weights.get(type1, 0.5), type_weights.get(type2, 0.5))
            spatial_weight = focus_weights[i][j] * density_map[i][j]
            type_score = 1.0 if type1 == type2 else 0.0
            total_score += type_score * base_weight * spatial_weight
            max_possible_score += base_weight * spatial_weight

    if max_possible_score > 0:
        return total_score / max_possible_score
    return 0.0

1
train.txt Normal file
View File

@ -0,0 +1 @@
python3 -u -m ginka.train_wgan --epochs 200 --checkpoint 20 --resume true --state_ginka result/wgan/ginka-400.pth --state_minamo result/wgan/minamo-400.pth >> output.log