perf: 改进 WGAN 训练

This commit is contained in:
unanmed 2025-04-10 22:42:58 +08:00
parent 268b21e0b7
commit 99f46150be
12 changed files with 711 additions and 127 deletions

View File

@ -7,6 +7,8 @@ from torch_geometric.data import Data
from minamo.model.model import MinamoModel
from shared.graph import batch_convert_soft_map_to_graph
from shared.constant import VISION_WEIGHT, TOPO_WEIGHT
from shared.similarity.topo import overall_similarity, build_topological_graph
from shared.similarity.vision import calculate_visual_similarity
CLASS_NUM = 32
ILLEGAL_MAX_NUM = 12
@ -286,32 +288,21 @@ def interpolate_graph_features(real_graph, fake_graph, epsilon=0.5):
return Data(x=x_interp, edge_index=edge_index_interp, edge_attr=edge_attr_interp)
def js_divergence(P, Q, epsilon=1e-10):
"""
输入:
P, Q: [B, C, H, W], 已通过 Softmax 处理
输出:
JS 散度标量全局平均
"""
# 转换为 [B, H, W, C] 以便在最后一维计算概率分布
P = P.permute(0, 2, 3, 1) # [B, H, W, C]
Q = Q.permute(0, 2, 3, 1)
def js_divergence(p, q, eps=1e-8):
# softmax 后变成概率分布
m = 0.5 * (p + q)
# 平均分布 M = (P + Q)/2
M = 0.5 * (P + Q)
# 计算 KL(P||M) 和 KL(Q||M)
kl_pm = F.kl_div(torch.log(M + epsilon), P, reduction='none', log_target=False).sum(dim=-1) # [B, H, W]
kl_qm = F.kl_div(torch.log(M + epsilon), Q, reduction='none', log_target=False).sum(dim=-1) # [B, H, W]
# JS 散度 = 0.5*(KL(P||M) + KL(Q||M))
js = 0.5 * (kl_pm + kl_qm)
# 全局平均(可替换为其他聚合方式)
return js.mean() # 标量
# log_softmax 以供 kl_div 使用
log_p = torch.log(p + eps)
log_q = torch.log(q + eps)
kl_pm = F.kl_div(log_p, m, reduction='batchmean', log_target=False) # KL(p || m)
kl_qm = F.kl_div(log_q, m, reduction='batchmean', log_target=False) # KL(q || m)
return torch.clamp(0.5 * (kl_pm + kl_qm), max=1.0)
class WGANGinkaLoss:
def __init__(self, lambda_gp=50, weight=[0.7, 0.2, 0.1], diversity_lamda=0.2):
def __init__(self, lambda_gp=100, weight=[0.8, 0.1, 0.1], diversity_lamda=0.4):
self.lambda_gp = lambda_gp # 梯度惩罚系数
self.weight = weight
self.diversity_lamda = diversity_lamda
@ -369,8 +360,50 @@ class WGANGinkaLoss:
return d_loss, d_loss + self.lambda_gp * grad_loss
def generator_loss_one(self, critic, fake):
fake_graph = batch_convert_soft_map_to_graph(fake)
def calculate_similarity_one(self, map1, map2):
    """Compute the ground-truth similarity of one pair of decoded maps.

    map1/map2 are 2D lists of integer tile ids (as produced by argmax over
    the soft map). Returns a (vis_sim, topo_sim) pair of floats in [0, 1]
    from the hand-written visual and topological similarity metrics.
    """
    topo1 = build_topological_graph(map1)
    topo2 = build_topological_graph(map2)
    vis_sim = calculate_visual_similarity(map1, map2)
    topo_sim = overall_similarity(topo1, topo2)
    return vis_sim, topo_sim
def discriminator_loss_assist(self, critic, fake_data1, fake_data2):
    """Supervise the assist critic's similarity heads on one map pair.

    The two soft maps are argmax-decoded on CPU, their "real" visual and
    topological similarities are computed per sample, and the critic's
    cosine similarity between its feature pairs is regressed onto those
    references with L1 losses mixed by VISION_WEIGHT / TOPO_WEIGHT.
    Returns a scalar loss tensor (on CPU, since both operands are moved there).
    """
    graph1 = batch_convert_soft_map_to_graph(fake_data1)
    graph2 = batch_convert_soft_map_to_graph(fake_data2)
    # Critic returns one feature vector per head: (vision, topology).
    vis_feat_1, topo_feat_1 = critic(fake_data1, graph1)
    vis_feat_2, topo_feat_2 = critic(fake_data2, graph2)
    # Hard-decode the soft maps to tile-id grids for the reference metrics.
    batch1 = torch.argmax(fake_data1, dim=1).cpu().tolist()
    batch2 = torch.argmax(fake_data2, dim=1).cpu().tolist()
    vis_sim_real = []
    topo_sim_real = []
    for i in range(len(batch1)):
        vis_sim, topo_sim = self.calculate_similarity_one(batch1[i], batch2[i])
        vis_sim_real.append(vis_sim)
        topo_sim_real.append(topo_sim)
    # Reference similarities as CPU float tensors (no gradients flow here).
    vis_sim_real = torch.Tensor(vis_sim_real)
    topo_sim_real = torch.Tensor(topo_sim_real)
    # Predictions moved to CPU so devices match the reference tensors.
    pred_vis_sim = F.cosine_similarity(vis_feat_1, vis_feat_2).cpu()
    pred_topo_sim = F.cosine_similarity(topo_feat_1, topo_feat_2).cpu()
    loss1 = F.l1_loss(pred_vis_sim, vis_sim_real) * VISION_WEIGHT + F.l1_loss(pred_topo_sim, topo_sim_real) * TOPO_WEIGHT
    return loss1
def discriminator_loss_assist2(self, critic, real_data, fake_data1, fake_data2):
    """Average the assist loss over the three pairings of real and fakes."""
    pairings = (
        (real_data, fake_data1),
        (real_data, fake_data2),
        (fake_data1, fake_data2),
    )
    total = 0.0
    # Accumulate each third separately (same float-addition order as before).
    for a, b in pairings:
        total = total + self.discriminator_loss_assist(critic, a, b) / 3.0
    return total
def generator_loss_one(self, critic, fake, fake_graph):
fake_scores, _, _ = critic(fake, fake_graph)
minamo_loss = -torch.mean(fake_scores)
class_loss = outer_border_constraint_loss(fake) + inner_constraint_loss(fake)
@ -383,16 +416,25 @@ class WGANGinkaLoss:
]
return sum(losses)
def diversity_loss(self, fake1, fake2):
    """JS divergence between two generated batches, ignoring the 1-tile border."""
    interior = (slice(None), slice(None), slice(1, -1), slice(1, -1))
    return js_divergence(fake1[interior], fake2[interior])
def generator_loss(self, critic, fake1, fake2):
def generator_loss(self, critic, critic_assist, fake1, fake2):
""" 生成器损失函数 """
loss1 = self.generator_loss_one(critic, fake1)
loss2 = self.generator_loss_one(critic, fake2)
fake_graph1 = batch_convert_soft_map_to_graph(fake1)
fake_graph2 = batch_convert_soft_map_to_graph(fake2)
return loss1 * 0.5 + loss2 * 0.5 - self.diversity_lamda * self.diversity_loss(fake1, fake2)
loss1 = self.generator_loss_one(critic, fake1, fake_graph1)
loss2 = self.generator_loss_one(critic, fake2, fake_graph2)
# vis_feat1, topo_feat1 = critic_assist(fake1, fake_graph1)
# vis_feat2, topo_feat2 = critic_assist(fake2, fake_graph2)
# vis_sim = F.cosine_similarity(vis_feat1, vis_feat2)
# topo_sim = F.cosine_similarity(topo_feat1, topo_feat2)
# similarity = vis_sim * VISION_WEIGHT + topo_sim * TOPO_WEIGHT
# print(similarity.mean().item())
# div_loss = F.l1_loss(fake1[:, :, 1:-1, 1:-1], fake2[:, :, 1:-1, 1:-1])
return loss1 * 0.5 + loss2 * 0.5\
# + self.diversity_lamda * F.relu(0.7 - div_loss).mean()
# + self.diversity_lamda * F.relu(similarity - 0.4).mean()

View File

@ -37,12 +37,12 @@ if __name__ == "__main__":
print_memory("初始化后")
# 前向传播
output, output_softmax = model(feat)
output = model(feat)
print_memory("前向传播后")
print(f"输入形状: feat={feat.shape}")
print(f"输出形状: output={output.shape}, softmax={output_softmax.shape}")
print(f"输出形状: output={output.shape}")
# print(f"Input parameters: {sum(p.numel() for p in model.input.parameters())}")
# print(f"Feature Encoder parameters: {sum(p.numel() for p in model.feat_enc.parameters())}")
print(f"UNet parameters: {sum(p.numel() for p in model.unet.parameters())}")

View File

@ -1,7 +1,9 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from shared.constant import FEAT_DIM
from torch_geometric.nn import GCNConv
from torch_geometric.utils import grid
from shared.attention import ChannelAttention
class GinkaTransformerEncoder(nn.Module):
def __init__(self, in_dim, hidden_dim, out_dim, token_size, ff_dim, num_heads=8, num_layers=6):
@ -33,7 +35,7 @@ class GinkaTransformerEncoder(nn.Module):
return x
class ConvBlock(nn.Module):
def __init__(self, in_ch, out_ch):
def __init__(self, in_ch, out_ch, atte=True):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='replicate'),
@ -41,11 +43,69 @@ class ConvBlock(nn.Module):
nn.ELU(),
nn.Conv2d(out_ch, out_ch, 3, padding=1, padding_mode='replicate'),
nn.InstanceNorm2d(out_ch),
nn.ELU(),
)
if atte:
self.conv.append(ChannelAttention(out_ch))
self.conv.append(nn.ELU())
def forward(self, x):
return self.conv(x)
class GCNBlock(nn.Module):
    """Two-layer GCN applied to a feature map viewed as a w x h grid graph.

    The edge index for a single h x w grid is built once from
    torch_geometric.utils.grid and replicated with per-sample offsets to
    cover the whole batch.
    NOTE(review): forward() uses the runtime H, W of the input for the node
    offsets but the stored edge index comes from the constructor's w, h —
    these must agree; confirm callers always pass matching sizes.
    """

    def __init__(self, in_ch, hidden_ch, out_ch, w, h):
        super().__init__()
        self.conv1 = GCNConv(in_ch, hidden_ch)
        self.conv2 = GCNConv(hidden_ch, out_ch)
        self.norm1 = nn.LayerNorm(hidden_ch)
        self.norm2 = nn.LayerNorm(out_ch)
        # [2, E] edge index for a single map, shared by all batch samples.
        self.single_edge_index, _ = grid(h, w)  # [2, E] for a single map

    def forward(self, x):
        # x: [B, C, H, W]
        B, C, H, W = x.shape
        # Flatten the map into per-cell node features: [B * H * W, C]
        x = x.permute(0, 2, 3, 1).reshape(B * H * W, C)
        # Construct batched edge index
        device = x.device
        edge_index = self._batch_edge_index(B, self.single_edge_index.to(device), H * W)
        # Batch vector for PyG (not strictly needed for GCNConv, but useful if you switch to GAT/Pooling)
        # batch = torch.arange(B, device=device).repeat_interleave(H * W)
        # GCN forward
        x = self.conv1(x, edge_index)
        x = F.elu(self.norm1(x))
        x = self.conv2(x, edge_index)
        x = F.elu(self.norm2(x))
        # Reshape back to a feature map: [B, out_ch, H, W]
        x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
        return x

    def _batch_edge_index(self, B, edge_index, num_nodes_per_batch):
        # Replicate the single-map edge index, offsetting node ids per sample.
        edge_index = edge_index.clone()  # [2, E]
        batch_edge_index = []
        for i in range(B):
            offset = i * num_nodes_per_batch
            batch_edge_index.append(edge_index + offset)
        return torch.cat(batch_edge_index, dim=1)
class FusionModule(nn.Module):
    """Fuse two same-shape feature maps: channel concat followed by 1x1 conv.

    in_ch must equal the sum of the two inputs' channel counts; the output
    has out_ch channels and the inputs' spatial size.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 1),
            nn.InstanceNorm2d(out_ch),
            nn.ELU()
        )

    def forward(self, x1, x2):
        # Concatenate along the channel dimension, then project.
        x = torch.cat([x1, x2], dim=1)
        x = self.conv(x)
        return x
class GinkaEncoder(nn.Module):
"""编码器(下采样)部分"""
@ -59,6 +119,21 @@ class GinkaEncoder(nn.Module):
x = self.pool(x)
return x
class GinkaGCNFusedEncoder(nn.Module):
    """Down-sampling encoder stage with a parallel grid-GCN branch.

    conv -> 2x maxpool, then a GCN pass over the pooled map; conv path and
    GCN path are concatenated and fused back to out_ch channels.
    w, h are the spatial dims AFTER pooling (must match GCNBlock's grid).
    """

    def __init__(self, in_ch, out_ch, w, h):
        super().__init__()
        self.conv = ConvBlock(in_ch, out_ch)
        self.gcn = GCNBlock(out_ch, out_ch*2, out_ch, w, h)
        self.pool = nn.MaxPool2d(2)
        # Fusion input is cat(conv path, gcn path) = 2 * out_ch channels.
        self.fusion = FusionModule(out_ch*2, out_ch)

    def forward(self, x):
        x = self.conv(x)
        x = self.pool(x)
        x2 = self.gcn(x)  # graph-structured refinement of the pooled map
        x = self.fusion(x, x2)
        return x
class GinkaUpSample(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
@ -83,6 +158,44 @@ class GinkaDecoder(nn.Module):
x = torch.cat([x, feat], dim=1)
x = self.conv(x)
return x
class GinkaGCNFusedDecoder(nn.Module):
    """Up-sampling decoder stage with a parallel grid-GCN branch.

    Upsamples, concatenates the encoder skip feature, convolves, then fuses
    the conv output with a GCN pass over it. w, h are the spatial dims of
    this stage's output (must match GCNBlock's grid).
    """

    def __init__(self, in_ch, out_ch, w, h):
        super().__init__()
        self.upsample = GinkaUpSample(in_ch, in_ch // 2)
        # conv input = upsampled (in_ch // 2) + skip feature channels.
        self.conv = ConvBlock(in_ch, out_ch)
        self.gcn = GCNBlock(out_ch, out_ch*2, out_ch, w, h)
        self.fusion = FusionModule(out_ch*2, out_ch)

    def forward(self, x, feat):
        x = self.upsample(x)
        x = torch.cat([x, feat], dim=1)  # U-Net skip connection
        x = self.conv(x)
        x2 = self.gcn(x)
        x = self.fusion(x, x2)
        return x
class GinkaBottleneck(nn.Module):
    """U-Net bottleneck: transformer branch + grid-GCN branch, fused.

    NOTE(review): forward() hardcodes 512 channels and a 4x4 / 16-token
    layout even though module_ch, w, h are constructor parameters — it only
    works for module_ch == 512, w == h == 4. Confirm before reusing with
    other sizes.
    """

    def __init__(self, module_ch, w, h):
        super().__init__()
        self.transformer = GinkaTransformerEncoder(
            in_dim=module_ch*w*h, hidden_dim=module_ch*w*h, out_dim=module_ch*w*h,
            token_size=16, ff_dim=1024, num_layers=4
        )
        self.gcn = GCNBlock(module_ch, module_ch*2, module_ch, 4, 4)
        self.fusion = FusionModule(module_ch*2, module_ch)

    def forward(self, x):
        B = x.size(0)
        # Tokenize the 4x4 map: [B, 512, 4, 4] -> [B, 16, 512]
        x1 = x.view(B, 512, 16).permute(0, 2, 1)  # [B, 16, in_ch]
        x1 = self.transformer(x1)
        # Back to a feature map: [B, 512, 4, 4]
        x1 = x1.permute(0, 2, 1).view(B, 512, 4, 4)  # [B, out_ch, 4, 4]
        x2 = self.gcn(x)
        x = self.fusion(x1, x2)
        return x
class GinkaUNet(nn.Module):
def __init__(self, base_ch=64, out_ch=32, feat_dim=1024):
@ -94,17 +207,14 @@ class GinkaUNet(nn.Module):
token_size=4, ff_dim=feat_dim*2, num_layers=4
)
self.down1 = ConvBlock(2, base_ch)
self.down2 = GinkaEncoder(base_ch, base_ch*2)
self.down3 = GinkaEncoder(base_ch*2, base_ch*4)
self.down2 = GinkaGCNFusedEncoder(base_ch, base_ch*2, 16, 16)
self.down3 = GinkaGCNFusedEncoder(base_ch*2, base_ch*4, 8, 8)
self.down4 = GinkaEncoder(base_ch*4, base_ch*8)
self.bottleneck = GinkaTransformerEncoder(
in_dim=base_ch*8*4*4, hidden_dim=base_ch*8*4*4, out_dim=base_ch*8*4*4,
token_size=16, ff_dim=1024, num_layers=4
)
self.bottleneck = GinkaBottleneck(base_ch*8, 4, 4)
self.up1 = GinkaDecoder(base_ch*8, base_ch*4)
self.up2 = GinkaDecoder(base_ch*4, base_ch*2)
self.up3 = GinkaDecoder(base_ch*2, base_ch)
self.up1 = GinkaGCNFusedDecoder(base_ch*8, base_ch*4, 8, 8)
self.up2 = GinkaGCNFusedDecoder(base_ch*4, base_ch*2, 16, 16)
self.up3 = GinkaGCNFusedDecoder(base_ch*2, base_ch, 32, 32)
self.final = nn.Sequential(
nn.Conv2d(base_ch, out_ch, 1),
@ -121,9 +231,7 @@ class GinkaUNet(nn.Module):
x2 = self.down2(x1) # [B, 128, 16, 16]
x3 = self.down3(x2) # [B, 256, 8, 8]
x4 = self.down4(x3) # [B, 512, 4, 4]
x4 = x4.view(B, 512, 16).permute(0, 2, 1) # [B, 16, 512]
x4 = self.bottleneck(x4) # [B, 16, 512]
x4 = x4.permute(0, 2, 1).view(B, 512, 4, 4) # [B, 512, 4, 4]
x4 = self.bottleneck(x4) # [B, 512, 4, 4]
# 上采样
x = self.up1(x4, x3) # [B, 256, 8, 8]

View File

@ -11,6 +11,7 @@ from .model.model import GinkaModel
from .dataset import GinkaWGANDataset
from .model.loss import WGANGinkaLoss
from minamo.model.model import MinamoScoreModule
from minamo.model.similarity import MinamoSimilarityModel
from shared.graph import batch_convert_soft_map_to_graph
from shared.image import matrix_to_image_cv
from shared.constant import VISION_WEIGHT, TOPO_WEIGHT
@ -26,10 +27,12 @@ disable_tqdm = not sys.stdout.isatty()
def parse_arguments():
parser = argparse.ArgumentParser(description="training codes")
parser.add_argument("--resume", type=bool, default=False)
parser.add_argument("--state_ginka", type=str, default="result/ginka.pth")
parser.add_argument("--state_minamo", type=str, default="result/minamo.pth")
parser.add_argument("--state_ginka", type=str, default="result/wgan/ginka-100.pth")
parser.add_argument("--state_minamo", type=str, default="result/wgan/minamo-100.pth")
parser.add_argument("--train", type=str, default="ginka-dataset.json")
parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--checkpoint", type=int, default=5)
parser.add_argument("--load_optim", type=bool, default=True)
args = parser.parse_args()
return args
@ -49,14 +52,20 @@ def train():
ginka = GinkaModel()
minamo = MinamoScoreModule()
minamo_sim = MinamoSimilarityModel()
ginka.to(device)
minamo.to(device)
minamo_sim.to(device)
dataset = GinkaWGANDataset(args.train, device)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9))
optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-5, betas=(0.0, 0.9))
optimizer_minamo_sim = optim.Adam(minamo_sim.parameters(), lr=1e-4)
# scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=args.epochs)
# scheduler_minamo = optim.lr_scheduler.CosineAnnealingLR(optimizer_minamo, T_max=args.epochs)
criterion = WGANGinkaLoss()
@ -67,14 +76,29 @@ def train():
tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED)
if args.resume:
data = torch.load(args.state_ginka, map_location=device)
ginka.load_state_dict(data["model_state"], strict=False)
data = torch.load(args.state_minamo, map_location=device)
minamo.load_state_dict(data["model_state"], strict=False)
data_ginka = torch.load(args.state_ginka, map_location=device)
data_minamo = torch.load(args.state_minamo, map_location=device)
ginka.load_state_dict(data_ginka["model_state"], strict=False)
minamo.load_state_dict(data_minamo["model_state"], strict=False)
if data_ginka["c_steps"] is not None and data_ginka["g_steps"] is not None:
c_steps = data_ginka["c_steps"]
g_steps = data_ginka["g_steps"]
if args.load_optim:
if data_ginka["optim_state"] is not None:
optimizer_ginka.load_state_dict(data_ginka["optim_state"])
if data_minamo["optim_state"] is not None:
optimizer_minamo.load_state_dict(data_minamo["optim_state"])
if data_minamo["optim_state_sim"] is not None:
optimizer_minamo_sim.load_state_dict(data_minamo["optim_state_sim"])
print("Train from loaded state.")
for epoch in tqdm(range(args.epochs), desc="GAN Training", disable=disable_tqdm):
for epoch in tqdm(range(args.epochs), desc="WGAN Training", disable=disable_tqdm):
loss_total_minamo = torch.Tensor([0]).to(device)
loss_total_minamo_sim = torch.Tensor([0]).to(device)
loss_total_ginka = torch.Tensor([0]).to(device)
dis_total = torch.Tensor([0]).to(device)
@ -83,13 +107,11 @@ def train():
real_data = real_data.to(device)
real_graph = batch_convert_soft_map_to_graph(real_data)
optimizer_ginka.zero_grad()
# ---------- 训练判别器
for _ in range(c_steps):
# 生成假样本
optimizer_minamo.zero_grad()
z = torch.randn(batch_size, 1024, device=device)
z = torch.rand(batch_size, 1024, device=device)
fake_data = ginka(z)
fake_data = fake_data.detach()
@ -97,59 +119,65 @@ def train():
# 反向传播
dis, loss_d = criterion.discriminator_loss(minamo, real_data, real_graph, fake_data)
loss_d.backward()
# torch.nn.utils.clip_grad_norm_(minamo.parameters(), max_norm=1.0)
# torch.nn.utils.clip_grad_norm_(minamo.parameters(), max_norm=2.0)
# total_norm = torch.linalg.vector_norm(torch.stack([torch.linalg.vector_norm(p.grad) for p in minamo.topo_model.parameters()]), 2)
# print("Critic 梯度范数:", total_norm.item())
# print("Critic 输入范围:", fake_data.min().item(), fake_data.max().item(), real_data.min().item(), real_data.max().item())
# print("Critic 输出范围:", d_real.min().item(), d_real.max().item())
optimizer_minamo.step()
loss_total_minamo += loss_d
dis_total += dis
loss_total_minamo += loss_d.detach()
dis_total += dis.detach()
# ---------- 训练生成器
for _ in range(g_steps):
optimizer_ginka.zero_grad()
# optimizer_minamo_sim.zero_grad()
z1 = torch.randn(batch_size, 1024, device=device)
z2 = torch.randn(batch_size, 1024, device=device)
fake_softmax1, fake_softmax2 = ginka(z1), ginka(z2)
loss_g = criterion.generator_loss(minamo, fake_softmax1, fake_softmax2)
# 先训练辅助判别器
# loss_c_assist = criterion.discriminator_loss_assist2(minamo_sim, real_data, fake_softmax1, fake_softmax2)
# loss_c_assist.backward(retain_graph=True)
# optimizer_minamo_sim.step()
loss_g = criterion.generator_loss(minamo, minamo_sim, fake_softmax1, fake_softmax2)
loss_g.backward()
optimizer_ginka.step()
loss_total_ginka += loss_g
# loss_total_minamo_sim += loss_c_assist.detach()
# tqdm.write(f"{dis.item():.12f}, {loss_d.item():.12f}, {loss_g.item():.12f}")
avg_loss_ginka = loss_total_ginka.item() / len(dataloader) / g_steps
avg_loss_minamo = loss_total_minamo.item() / len(dataloader) / c_steps
avg_loss_minamo_sim = loss_total_minamo_sim.item() / len(dataloader) / g_steps
avg_dis = dis_total.item() / len(dataloader) / c_steps
tqdm.write(
f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | Wasserstein Loss: {avg_dis:.8f} | Loss Ginka: {avg_loss_ginka:.8f} | Loss Minamo: {avg_loss_minamo:.8f}"
f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +\
f"Epoch: {epoch + 1} | W Loss: {avg_dis:.8f} | " +\
f"G Loss: {avg_loss_ginka:.8f} | D Loss: {avg_loss_minamo:.8f} | " +\
f"lr G: {(optimizer_ginka.param_groups[0]['lr']):.8f}"
)
# scheduler_ginka.step()
# scheduler_minamo.step()
if avg_dis < 0:
g_steps = max(int(-avg_dis * 5), 1)
else:
g_steps = 1
# if avg_dis > 0:
# c_steps = min(max(int(avg_dis * 5), 1), 5)
# else:
# c_steps = 1
# if avg_loss_minamo > 0:
# c_steps += min(max(int(avg_loss_minamo * 3), 1), 5)
# else:
# c_steps += 0
# if avg_dis > 3:
# c_steps = 3
# else:
# c_steps = 1
if avg_loss_ginka > 0 or avg_loss_minamo > 0:
c_steps = min(5 + (avg_loss_ginka + avg_loss_minamo) * 5, 15)
else:
c_steps = 5
# 每若干轮输出一次图片,并保存检查点
if (epoch + 1) % 5 == 0:
if (epoch + 1) % args.checkpoint == 0:
# 输出 20 张图片,每批次 4 张,一共五批
idx = 0
with torch.no_grad():
@ -165,18 +193,25 @@ def train():
# 保存检查点
torch.save({
"model_state": ginka.state_dict()
"model_state": ginka.state_dict(),
"optim_state": optimizer_ginka.state_dict(),
"c_steps": c_steps,
"g_steps": g_steps
}, f"result/wgan/ginka-{epoch + 1}.pth")
torch.save({
"model_state": minamo.state_dict()
"model_state": minamo.state_dict(),
"model_state_sim": minamo_sim.state_dict(),
"optim_state": optimizer_minamo.state_dict(),
"optim_state_sim": optimizer_minamo_sim.state_dict()
}, f"result/wgan/minamo-{epoch + 1}.pth")
print("Train ended.")
torch.save({
"model_state": ginka.state_dict()
"model_state": ginka.state_dict(),
}, f"result/ginka.pth")
torch.save({
"model_state": minamo.state_dict()
"model_state": minamo.state_dict(),
"model_state_sim": minamo_sim.state_dict(),
}, f"result/minamo.pth")
if __name__ == "__main__":

View File

@ -20,40 +20,6 @@ class MinamoModel(nn.Module):
return vision_feat, topo_feat
class MinamoVisionScore(nn.Module):
def __init__(self, tile_types=32):
super().__init__()
# 视觉相似度部分
self.vision_model = MinamoVisionModel(tile_types)
# 输出层
self.vision_fc = nn.Sequential(
nn.Linear(512, 2048),
nn.LeakyReLU(0.2),
nn.Linear(2048, 1)
)
def forward(self, map):
vision_feat = self.vision_model(map)
vis_score = self.vision_fc(vision_feat)
return vis_score
class MinamoTopoScore(nn.Module):
def __init__(self, tile_types=32):
super().__init__()
# 拓扑相似度部分
self.topo_model = MinamoTopoModel(tile_types)
# 输出层
self.topo_fc = nn.Sequential(
nn.Linear(512, 2048),
nn.LeakyReLU(0.2),
nn.Linear(2048, 1)
)
def forward(self, graph):
topo_feat = self.topo_model(graph)
topo_score = self.topo_fc(topo_feat)
return topo_score
class MinamoScoreModule(nn.Module):
def __init__(self, tile_types=32):
super().__init__()

View File

@ -0,0 +1,83 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.data import Data
class MinamoSimilarityVision(nn.Module):
    """CNN feature extractor for the visual branch of the assist critic.

    Maps [B, in_ch, H, W] inputs (presumably soft/one-hot tile maps —
    confirm with callers) to [B, out_ch] feature vectors; callers compare
    pairs of these with cosine similarity.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, in_ch * 2, 3, padding=1),
            nn.InstanceNorm2d(in_ch * 2),
            nn.ReLU(),
            nn.Conv2d(in_ch * 2, in_ch * 4, 3, padding=1),
            nn.InstanceNorm2d(in_ch * 4),
            nn.ReLU(),
            # No padding here, so H and W shrink by 2 before pooling.
            nn.Conv2d(in_ch * 4, in_ch * 8, 3),
            nn.InstanceNorm2d(in_ch * 8),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1)  # -> [B, in_ch * 8, 1, 1]
        )
        self.fc = nn.Sequential(
            nn.Linear(in_ch * 8, out_ch),
        )

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)  # flatten the pooled features
        x = self.fc(x)
        return x
class MinamoSimilarityTopo(nn.Module):
    """GCN feature extractor for the topological branch of the assist critic.

    Consumes a PyG Data batch (graph.x node features, graph.edge_index,
    graph.batch) and returns one mean-pooled feature vector per graph,
    shape [num_graphs, out_ch].
    """

    def __init__(self, in_ch, hidden_dim, out_ch):
        super().__init__()
        self.input_fc = nn.Sequential(
            nn.Linear(in_ch, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
        )
        # Channel width doubles at each of the three GCN layers.
        self.conv1 = GCNConv(hidden_dim, hidden_dim*2)
        self.conv2 = GCNConv(hidden_dim*2, hidden_dim*4)
        self.conv3 = GCNConv(hidden_dim*4, hidden_dim*8)
        self.norm1 = nn.LayerNorm(hidden_dim*2)
        self.norm2 = nn.LayerNorm(hidden_dim*4)
        self.norm3 = nn.LayerNorm(hidden_dim*8)
        self.output_fc = nn.Sequential(
            nn.Linear(hidden_dim*8, out_ch)
        )

    def forward(self, graph: Data):
        x = self.input_fc(graph.x)
        x = self.conv1(x, graph.edge_index)
        x = F.relu(self.norm1(x))
        x = self.conv2(x, graph.edge_index)
        x = F.relu(self.norm2(x))
        x = self.conv3(x, graph.edge_index)
        x = F.relu(self.norm3(x))
        # Mean-pool node features per graph in the batch.
        x = global_mean_pool(x, graph.batch)
        x = self.output_fc(x)
        return x
class MinamoSimilarityModel(nn.Module):
    """Assist critic with independent vision and topology feature extractors.

    forward(x, graph) returns (vis_feat, topo_feat), each of width 512;
    callers compare feature pairs with cosine similarity.
    """

    def __init__(self, tile_type=32):
        super().__init__()
        self.vision = MinamoSimilarityVision(tile_type, 512)
        self.topo = MinamoSimilarityTopo(tile_type, 64, 512)

    def forward(self, x, graph):
        vis_feat = self.vision(x)
        topo_feat = self.topo(graph)
        return vis_feat, topo_feat

View File

@ -7,16 +7,16 @@ class MinamoVisionModel(nn.Module):
def __init__(self, in_ch=32, out_dim=512):
super().__init__()
self.conv = nn.Sequential(
spectral_norm(nn.Conv2d(in_ch, in_ch*2, 3, stride=2)), # 6*6
spectral_norm(nn.Conv2d(in_ch, in_ch*2, 3)), # 11*11
nn.LeakyReLU(0.2),
spectral_norm(nn.Conv2d(in_ch*2, in_ch*4, 3)), #4*4
spectral_norm(nn.Conv2d(in_ch*2, in_ch*4, 3)), #9*9
nn.LeakyReLU(0.2),
spectral_norm(nn.Conv2d(in_ch*4, in_ch*8, 3)), # 2*2
spectral_norm(nn.Conv2d(in_ch*4, in_ch*8, 3)), # 7*7
nn.LeakyReLU(0.2),
nn.Flatten()
nn.AdaptiveAvgPool2d(2)
)
self.fc = nn.Sequential(
spectral_norm(nn.Linear(in_ch*8*2*2, out_dim)),
@ -25,5 +25,6 @@ class MinamoVisionModel(nn.Module):
def forward(self, x):
x = self.conv(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x

View File

@ -9,7 +9,7 @@ class ChannelAttention(nn.Module):
self.channel_att = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(channels, channels//reduction, 1),
nn.GELU(),
nn.ELU(),
nn.Conv2d(channels//reduction, channels, 1),
nn.Sigmoid()
)

View File

@ -2,5 +2,5 @@ VIS_DIM = 512
TOPO_DIM = 512
FEAT_DIM = 1024
VISION_WEIGHT = 0.3
TOPO_WEIGHT = 0.7
VISION_WEIGHT = 0.2
TOPO_WEIGHT = 0.8

273
shared/similarity/topo.py Normal file
View File

@ -0,0 +1,273 @@
# Converted Python version of the JS code
import math
from typing import Dict, Set, List, Tuple, Union
from collections import deque, defaultdict
# Topological similarity; transpiled from the TypeScript version by ChatGPT-4o
class ResourceArea:
    """A 4-connected region of resource tiles plus its adjacency info."""

    def __init__(self):
        self.type = 'resource'
        # tile id -> number of occurrences inside this area
        self.resources: Dict[int, int] = {}
        # flat indices (y * width + x) of the cells belonging to the area
        self.members: Set[int] = set()
        # flat indices of adjacent branch cells
        self.neighbor: Set[int] = set()


class BranchNode:
    """Graph node for a tile in BRANCH_TYPE."""

    def __init__(self, tile: int):
        self.type = 'branch'
        self.tile = tile
        self.neighbor: Set[int] = set()


class ResourceNode:
    """Graph node representing one member cell of a ResourceArea."""

    def __init__(self, resource_type: int, area: ResourceArea):
        self.type = 'resource'
        self.resourceType = resource_type
        # Shared with the owning area, so later updates propagate to all members.
        self.neighbor = area.neighbor
        self.resourceArea = area


# Either kind of node stored in GinkaGraph.graph.
GinkaNode = Union[BranchNode, ResourceNode]


class GinkaGraph:
    """Topological graph of the region reachable from a single entrance."""

    def __init__(self):
        # flat index -> node
        self.graph: Dict[int, GinkaNode] = {}
        # flat index of a resource cell -> index into areaMap
        self.resourceMap: Dict[int, int] = {}
        self.areaMap: List[ResourceArea] = []
        # entrances reached during this graph's BFS
        self.visitedEntrance: Set[int] = set()
        # all cells reached during this graph's BFS
        self.visited: Set[int] = set()


class GinkaTopologicalGraphs:
    """All per-entrance graphs of one map, plus its unreachable cells."""

    def __init__(self):
        self.graphs: List[GinkaGraph] = []
        # entrance flat index -> the graph containing it
        self.entranceMap: Dict[int, GinkaGraph] = {}
        self.unreachable: Set[int] = set()


# Tile-id categories. The meaning of individual ids comes from the game's
# tile table — NOTE(review): confirm against that table; only the category
# membership is visible here.
TILE_TYPE = set(range(13))
BRANCH_TYPE = {6, 7, 8, 9}
ENTRANCE_TYPE = {10, 11}
RESOURCE_TYPE = {0, 2, 3, 4, 5, 10, 11, 12}

# (dx, dy) offsets of the 4-neighborhood
directions: List[Tuple[int, int]] = [
    (-1, 0), (1, 0), (0, -1), (0, 1)
]
def find_resource_nodes(map_: List[List[int]]):
    """Flood-fill the 4-connected resource areas of *map_*.

    Returns (areas, resource_map): *areas* is a list of ResourceArea and
    *resource_map* maps each resource cell's flat index (y * width + x) to
    the index of its area in *areas*.

    Fix: the original pre-seeded area.resources / area.members with the
    starting cell AND counted it again when the BFS popped it, so the seed
    tile's count was always off by one. The BFS loop below now does all
    the accounting exactly once per cell.
    """
    width, height = len(map_[0]), len(map_)
    visited = set()
    areas = []
    resource_map = {}
    for ny in range(height):
        for nx in range(width):
            tile = map_[ny][nx]
            index = ny * width + nx
            if index in visited or tile not in RESOURCE_TYPE:
                continue
            # BFS over the 4-connected resource region starting here.
            queue = deque([(nx, ny)])
            area = ResourceArea()
            while queue:
                cx, cy = queue.popleft()
                cindex = cy * width + cx
                if cindex in visited:
                    continue
                ctile = map_[cy][cx]
                if ctile not in RESOURCE_TYPE:
                    continue
                visited.add(cindex)
                area.resources[ctile] = area.resources.get(ctile, 0) + 1
                area.members.add(cindex)
                resource_map[cindex] = len(areas)
                for dx, dy in directions:
                    px, py = cx + dx, cy + dy
                    if 0 <= px < width and 0 <= py < height:
                        queue.append((px, py))
            areas.append(area)
    return areas, resource_map
def build_graph_from_entrance(map_: List[List[int]], entrance: int, resource_map: Dict[int, int], area_map: List[ResourceArea]) -> GinkaGraph:
    """BFS from *entrance* and build the topological graph of its region.

    The flood fill expands through every non-wall cell (tile id 1 is the
    wall), recording visited cells, reached entrances and branch tiles.
    Branch nodes are then linked to adjacent branches and, through adjacent
    resource areas, to every member cell of those areas.
    """
    width, height = len(map_[0]), len(map_)
    graph = GinkaGraph()
    graph.resourceMap = resource_map
    graph.areaMap = area_map
    visited = graph.visited
    visited_entrance = graph.visitedEntrance
    visited_entrance.add(entrance)
    branch_nodes = set()
    # ---- flood fill from the entrance over all non-wall cells
    queue = deque([(entrance % width, entrance // width)])
    while queue:
        x, y = queue.popleft()
        index = y * width + x
        if index in visited:
            continue
        tile = map_[y][x]
        if tile in ENTRANCE_TYPE:
            visited_entrance.add(index)
        if tile in BRANCH_TYPE:
            branch_nodes.add(index)
        visited.add(index)
        for dx, dy in directions:
            px, py = x + dx, y + dy
            # Walls (tile 1) are never enqueued.
            if 0 <= px < width and 0 <= py < height and map_[py][px] != 1:
                queue.append((px, py))
    # ---- wire each branch node to its 4-neighbors
    for v in branch_nodes:
        x, y = v % width, v // width
        if v not in graph.graph:
            graph.graph[v] = BranchNode(map_[y][x])
        node = graph.graph[v]
        for dx, dy in directions:
            px, py = x + dx, y + dy
            if 0 <= px < width and 0 <= py < height:
                index = py * width + px
                if index in branch_nodes:
                    node.neighbor.add(index)
                elif index in resource_map:
                    # Adjacent resource area: link both directions, and
                    # connect the branch to every member cell of the area.
                    area = area_map[resource_map[index]]
                    area.neighbor.add(v)
                    for m in area.members:
                        node.neighbor.add(m)
    # ---- materialize resource nodes; plain floor (tile 0) gets no node.
    # NOTE(review): this iterates every area of the whole map, not only the
    # areas reached by this BFS — confirm that is intended.
    for area in area_map:
        for index in area.members:
            x, y = index % width, index // width
            tile = map_[y][x]
            if tile == 0:
                continue
            node = ResourceNode(tile, area)
            graph.graph[index] = node
    return graph
def build_topological_graph(map_: List[List[int]]) -> GinkaTopologicalGraphs:
    """Build the per-entrance topological graphs of a tile map.

    Every entrance tile spawns one BFS graph; entrances already reached by
    an earlier BFS are skipped. Cells reached by no BFS and not walls
    (tile id 1) are recorded as unreachable.

    Fix: dropped the dead `entrances = set()` assignment that was
    immediately overwritten by the comprehension.
    """
    width, height = len(map_[0]), len(map_)
    # Flat indices of all entrance tiles.
    entrances = {y * width + x for y in range(height) for x in range(width) if map_[y][x] in ENTRANCE_TYPE}
    area_map, resource_map = find_resource_nodes(map_)
    top_graph = GinkaTopologicalGraphs()
    used_entrance = set()
    total_visited = set()
    for entrance in entrances:
        if entrance in used_entrance:
            continue
        graph = build_graph_from_entrance(map_, entrance, resource_map, area_map)
        top_graph.graphs.append(graph)
        for ent in graph.visitedEntrance:
            used_entrance.add(ent)
            top_graph.entranceMap[ent] = graph
        total_visited.update(graph.visited)
    # Everything never visited and not a wall is unreachable.
    for y in range(height):
        for x in range(width):
            index = y * width + x
            if index not in total_visited and map_[y][x] != 1:
                top_graph.unreachable.add(index)
    return top_graph
class WLNode:
    """Mutable node wrapper used by the Weisfeiler-Lehman relabeling."""

    def __init__(self, pos: int, label: str):
        # Flat map index of the underlying graph node.
        self.originalPos = pos
        # Initial label (node kind + tile/resource id); used for the final weighting.
        self.originalLabel = label
        # Refined in place on every WL iteration.
        self.currentLabel = label
        self.neighbors: List['WLNode'] = []
def encode_node_labels(graph: GinkaGraph) -> List[WLNode]:
    """Wrap every graph node in a WLNode whose initial label encodes its kind.

    Branch nodes become "B:<tile>", resource nodes "R:<resourceType>".
    Neighbor links are only kept when the neighbor itself has a node.
    """
    def initial_label(g_node):
        return f"B:{g_node.tile}" if g_node.type == 'branch' else f"R:{g_node.resourceType}"

    node_map = {pos: WLNode(pos, initial_label(g_node)) for pos, g_node in graph.graph.items()}
    nodes = list(node_map.values())
    for wl_node in nodes:
        for neighbor in graph.graph[wl_node.originalPos].neighbor:
            if neighbor in node_map:
                wl_node.neighbors.append(node_map[neighbor])
    return nodes
def weisfeiler_lehman_iteration(nodes: List[WLNode], iterations: int, decay: float = 0.6) -> Dict[str, float]:
    """Run WL label refinement and return a decay-weighted label histogram.

    Layer i of refined labels contributes with weight decay**i; the
    original labels contribute with the weight left after the last layer.
    Composite labels are capped at 8192 chars. Mutates node.currentLabel.
    """
    layers: List[List[str]] = []
    for _ in range(iterations):
        refreshed = [
            f"{node.currentLabel}|{','.join(sorted(n.currentLabel for n in node.neighbors))}"[:8192]
            for node in nodes
        ]
        for node, label in zip(nodes, refreshed):
            node.currentLabel = label
        layers.append(list(refreshed))
    counts = defaultdict(float)
    weight = 1.0
    for layer in layers:
        for label in layer:
            counts[label] += weight
        weight *= decay
    # Seed labels get the smallest (final) weight.
    for node in nodes:
        counts[node.originalLabel] += weight
    return dict(counts)
def vectorize_features(features: Dict[str, float], vocab: List[str]) -> List[float]:
    """Project a label -> count map onto *vocab*, one slot per vocab entry.

    Labels absent from *vocab* are ignored. Uses a prebuilt label -> index
    map (first occurrence wins, matching list.index) so the projection is
    O(len(vocab) + len(features)) instead of the original
    O(len(features) * len(vocab)) repeated vocab.index() scans.
    """
    index_of: Dict[str, int] = {}
    for i, label in enumerate(vocab):
        index_of.setdefault(label, i)  # keep the FIRST occurrence, like .index()
    vec = [0.0] * len(vocab)
    for label, count in features.items():
        idx = index_of.get(label)
        if idx is not None:
            vec[idx] += count
    return vec
def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Cosine similarity of two vectors; 0.0 when either has zero norm."""
    dot = 0.0
    for x, y in zip(a, b):
        dot += x * y
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)
def wl_kernel(graph_a: GinkaGraph, graph_b: GinkaGraph, iterations: int = 3) -> float:
    """WL-kernel similarity of two graphs via decay-weighted label histograms."""
    features_a, features_b = (
        weisfeiler_lehman_iteration(encode_node_labels(g), iterations)
        for g in (graph_a, graph_b)
    )
    # Shared vocabulary over both histograms; order is irrelevant to the result.
    vocab = list(set(features_a).union(features_b))
    vec_a = vectorize_features(features_a, vocab)
    vec_b = vectorize_features(features_b, vocab)
    return cosine_similarity(vec_a, vec_b)
def overall_similarity(a: GinkaTopologicalGraphs, b: GinkaTopologicalGraphs) -> float:
    """Topological similarity of two maps in [0, 1].

    Greedily matches each graph of *a* with its best still-unmatched graph
    of *b* by WL-kernel similarity, averages the best scores, takes a
    square root, and damps the result by the difference in the two maps'
    unreachable-cell counts.
    """
    graphs_a = a.graphs
    graphs_b = b.graphs
    total_similarity = 0.0
    compared_graphs: Set[GinkaGraph] = set()
    for ga in graphs_a:
        max_similarity = 0.0
        max_graph = None
        for gb in graphs_b:
            if gb in compared_graphs:
                continue
            # Fewer WL iterations for small graphs (log of the smaller node count).
            # NOTE(review): math.log(0) raises ValueError if a graph has no
            # nodes — confirm upstream guarantees non-empty graphs.
            min_nodes = min(len(ga.graph), len(gb.graph))
            iterations = max(1, math.ceil(math.log(min_nodes)))
            similarity = wl_kernel(ga, gb, iterations)
            if similarity > max_similarity and not math.isnan(similarity):
                max_similarity = similarity
                max_graph = gb
            if similarity == 1:
                break  # perfect match; stop searching
        total_similarity += max_similarity
        if max_graph:
            # Each graph of *b* can be matched at most once.
            compared_graphs.add(max_graph)
    # Penalize differing amounts of unreachable space.
    reduction = 1 / (1 + abs(len(a.unreachable) - len(b.unreachable)))
    if not graphs_a:
        return 0.0
    return math.sqrt(total_similarity / len(graphs_a)) * reduction

View File

@ -0,0 +1,75 @@
from typing import List, Dict
import math
import numpy as np
# Visual similarity; transpiled from the TypeScript version by ChatGPT-4o
class VisualSimilarityConfig:
    """Tunables for calculate_visual_similarity."""

    def __init__(self):
        # Per-tile-type importance weight; types missing from this dict
        # default to 0.5 at the call sites.
        # NOTE(review): values look hand-tuned — confirm against the tile table.
        self.type_weights: Dict[int, float] = {
            0: 0.2, 1: 0.3, 2: 0.6, 3: 0.7, 4: 0.7, 5: 0.5,
            6: 0.4, 7: 0.5, 8: 0.6, 9: 0.6, 10: 0.4, 11: 0.4, 12: 0.7
        }
        # Weight central tiles higher via a Gaussian focus map.
        self.enable_visual_focus: bool = True
        # Scale per-cell weight by local tile density (3x3 window).
        self.enable_density_awareness: bool = True
def generate_focus_weights(rows: int, cols: int) -> List[List[float]]:
weights = []
center_x = cols / 2
center_y = rows / 2
for i in range(rows):
row_weights = []
for j in range(cols):
dx = (j - center_x) / cols
dy = (i - center_y) / rows
distance = math.sqrt(dx ** 2 + dy ** 2)
gaussian = math.exp(-(distance ** 2) / (2 * 0.3 ** 2))
row_weights.append(1.0 + 0.6 * gaussian)
weights.append(row_weights)
return weights
def calculate_density_impact(map1: List[List[int]], map2: List[List[int]], type_weights: Dict[int, float]) -> List[List[float]]:
    """Per-cell density boost: 1.0 + 0.4 * (summed pair-weight in a 3x3 window / 9).

    The pair weight of a cell is the mean of the two maps' type weights at
    that cell (unknown types default to 0.5). Out-of-range window cells
    contribute nothing, so border cells get smaller boosts.
    """
    rows, cols = len(map1), len(map1[0])
    # Precompute the averaged per-cell weight of the two maps.
    avg_weight = [
        [(type_weights.get(map1[i][j], 0.5) + type_weights.get(map2[i][j], 0.5)) / 2
         for j in range(cols)]
        for i in range(rows)
    ]
    half = 1  # 3x3 window
    result = []
    for i in range(rows):
        row = []
        for j in range(cols):
            density = 0
            for ni in range(max(0, i - half), min(rows, i + half + 1)):
                for nj in range(max(0, j - half), min(cols, j + half + 1)):
                    density += avg_weight[ni][nj]
            row.append(1.0 + 0.4 * (density / 9))
        result.append(row)
    return result
def calculate_visual_similarity(map1: List[List[int]], map2: List[List[int]], config: VisualSimilarityConfig = None) -> float:
    """Weighted per-tile agreement of two tile maps, in [0, 1].

    Each matching cell scores its per-type weight times an optional
    center-focus weight and an optional local-density weight; the result is
    the achieved score over the maximum possible score. Mismatched shapes
    return 0.0.

    Fix: empty maps previously raised IndexError on map1[0] in the shape
    check; they now return 0.0 like any other shape mismatch.
    """
    if config is None:
        config = VisualSimilarityConfig()
    # Guard empty input first: the shape check below indexes row 0.
    if not map1 or not map2 or not map1[0] or not map2[0]:
        return 0.0
    if len(map1) != len(map2) or len(map1[0]) != len(map2[0]):
        return 0.0
    rows, cols = len(map1), len(map1[0])
    total_score = 0.0
    max_possible_score = 0.0
    focus_weights = generate_focus_weights(rows, cols) if config.enable_visual_focus else [[1.0 for _ in range(cols)] for _ in range(rows)]
    density_map = calculate_density_impact(map1, map2, config.type_weights) if config.enable_density_awareness else [[1.0 for _ in range(cols)] for _ in range(rows)]
    for i in range(rows):
        for j in range(cols):
            type1 = map1[i][j]
            type2 = map2[i][j]
            # A mismatching pair still raises the attainable maximum, so
            # disagreement on heavily weighted tiles costs more.
            base_weight = max(config.type_weights.get(type1, 0.5), config.type_weights.get(type2, 0.5))
            spatial_weight = focus_weights[i][j] * density_map[i][j]
            type_score = 1.0 if type1 == type2 else 0.0
            total_score += type_score * base_weight * spatial_weight
            max_possible_score += base_weight * spatial_weight
    return total_score / max_possible_score if max_possible_score > 0 else 0.0

1
train.txt Normal file
View File

@ -0,0 +1 @@
python3 -u -m ginka.train_wgan --epochs 200 --checkpoint 20 --resume true --state_ginka result/wgan/ginka-400.pth --state_minamo result/wgan/minamo-400.pth >> output.log