diff --git a/ginka/model/loss.py b/ginka/model/loss.py index 5b004e5..0c5e894 100644 --- a/ginka/model/loss.py +++ b/ginka/model/loss.py @@ -7,6 +7,8 @@ from torch_geometric.data import Data from minamo.model.model import MinamoModel from shared.graph import batch_convert_soft_map_to_graph from shared.constant import VISION_WEIGHT, TOPO_WEIGHT +from shared.similarity.topo import overall_similarity, build_topological_graph +from shared.similarity.vision import calculate_visual_similarity CLASS_NUM = 32 ILLEGAL_MAX_NUM = 12 @@ -286,32 +288,21 @@ def interpolate_graph_features(real_graph, fake_graph, epsilon=0.5): return Data(x=x_interp, edge_index=edge_index_interp, edge_attr=edge_attr_interp) -def js_divergence(P, Q, epsilon=1e-10): - """ - 输入: - P, Q: [B, C, H, W], 已通过 Softmax 处理 - 输出: - JS 散度标量(全局平均) - """ - # 转换为 [B, H, W, C] 以便在最后一维计算概率分布 - P = P.permute(0, 2, 3, 1) # [B, H, W, C] - Q = Q.permute(0, 2, 3, 1) +def js_divergence(p, q, eps=1e-8): + # softmax 后变成概率分布 + m = 0.5 * (p + q) - # 平均分布 M = (P + Q)/2 - M = 0.5 * (P + Q) - - # 计算 KL(P||M) 和 KL(Q||M) - kl_pm = F.kl_div(torch.log(M + epsilon), P, reduction='none', log_target=False).sum(dim=-1) # [B, H, W] - kl_qm = F.kl_div(torch.log(M + epsilon), Q, reduction='none', log_target=False).sum(dim=-1) # [B, H, W] - - # JS 散度 = 0.5*(KL(P||M) + KL(Q||M)) - js = 0.5 * (kl_pm + kl_qm) - - # 全局平均(可替换为其他聚合方式) - return js.mean() # 标量 + # log_softmax 以供 kl_div 使用 + log_p = torch.log(p + eps) + log_q = torch.log(q + eps) + + kl_pm = F.kl_div(log_p, m, reduction='batchmean', log_target=False) # KL(p || m) + kl_qm = F.kl_div(log_q, m, reduction='batchmean', log_target=False) # KL(q || m) + + return torch.clamp(0.5 * (kl_pm + kl_qm), max=1.0) class WGANGinkaLoss: - def __init__(self, lambda_gp=50, weight=[0.7, 0.2, 0.1], diversity_lamda=0.2): + def __init__(self, lambda_gp=100, weight=[0.8, 0.1, 0.1], diversity_lamda=0.4): self.lambda_gp = lambda_gp # 梯度惩罚系数 self.weight = weight self.diversity_lamda = diversity_lamda @@ -369,8 +360,50 @@ class WGANGinkaLoss: return d_loss, d_loss + self.lambda_gp * grad_loss - def generator_loss_one(self, critic, fake): - fake_graph = batch_convert_soft_map_to_graph(fake) + def calculate_similarity_one(self, map1, map2): + topo1 = build_topological_graph(map1) + topo2 = build_topological_graph(map2) + + vis_sim = calculate_visual_similarity(map1, map2) + topo_sim = overall_similarity(topo1, topo2) + + return vis_sim, topo_sim + + def discriminator_loss_assist(self, critic, fake_data1, fake_data2): + graph1 = batch_convert_soft_map_to_graph(fake_data1) + graph2 = batch_convert_soft_map_to_graph(fake_data2) + vis_feat_1, topo_feat_1 = critic(fake_data1, graph1) + vis_feat_2, topo_feat_2 = critic(fake_data2, graph2) + + batch1 = torch.argmax(fake_data1, dim=1).cpu().tolist() + batch2 = torch.argmax(fake_data2, dim=1).cpu().tolist() + + vis_sim_real = [] + topo_sim_real = [] + + for i in range(len(batch1)): + vis_sim, topo_sim = self.calculate_similarity_one(batch1[i], batch2[i]) + vis_sim_real.append(vis_sim) + topo_sim_real.append(topo_sim) + + vis_sim_real = torch.Tensor(vis_sim_real) + topo_sim_real = torch.Tensor(topo_sim_real) + + pred_vis_sim = F.cosine_similarity(vis_feat_1, vis_feat_2).cpu() + pred_topo_sim = F.cosine_similarity(topo_feat_1, topo_feat_2).cpu() + + loss1 = F.l1_loss(pred_vis_sim, vis_sim_real) * VISION_WEIGHT + F.l1_loss(pred_topo_sim, topo_sim_real) * TOPO_WEIGHT + + return loss1 + + def discriminator_loss_assist2(self, critic, real_data, fake_data1, fake_data2): + loss1 = self.discriminator_loss_assist(critic, real_data, fake_data1) + loss2 = self.discriminator_loss_assist(critic, real_data, fake_data2) + loss3 = self.discriminator_loss_assist(critic, fake_data1, fake_data2) + + return loss1 / 3.0 + loss2 / 3.0 + loss3 / 3.0 + + def generator_loss_one(self, critic, fake, fake_graph): fake_scores, _, _ = critic(fake, fake_graph) minamo_loss = -torch.mean(fake_scores) class_loss = outer_border_constraint_loss(fake) + inner_constraint_loss(fake) @@ -383,16 +416,25 @@ class WGANGinkaLoss: ] return sum(losses) - - def diversity_loss(self, fake1, fake2): - fake1 = fake1[:, :, 1:-1, 1:-1] - fake2 = fake2[:, :, 1:-1, 1:-1] - - return js_divergence(fake1, fake2) - def generator_loss(self, critic, fake1, fake2): + def generator_loss(self, critic, critic_assist, fake1, fake2): """ 生成器损失函数 """ - loss1 = self.generator_loss_one(critic, fake1) - loss2 = self.generator_loss_one(critic, fake2) + fake_graph1 = batch_convert_soft_map_to_graph(fake1) + fake_graph2 = batch_convert_soft_map_to_graph(fake2) - return loss1 * 0.5 + loss2 * 0.5 - self.diversity_lamda * self.diversity_loss(fake1, fake2) + loss1 = self.generator_loss_one(critic, fake1, fake_graph1) + loss2 = self.generator_loss_one(critic, fake2, fake_graph2) + + # vis_feat1, topo_feat1 = critic_assist(fake1, fake_graph1) + # vis_feat2, topo_feat2 = critic_assist(fake2, fake_graph2) + + # vis_sim = F.cosine_similarity(vis_feat1, vis_feat2) + # topo_sim = F.cosine_similarity(topo_feat1, topo_feat2) + # similarity = vis_sim * VISION_WEIGHT + topo_sim * TOPO_WEIGHT + + # print(similarity.mean().item()) + # div_loss = F.l1_loss(fake1[:, :, 1:-1, 1:-1], fake2[:, :, 1:-1, 1:-1]) + + return loss1 * 0.5 + loss2 * 0.5\ + # + self.diversity_lamda * F.relu(0.7 - div_loss).mean() + # + self.diversity_lamda * F.relu(similarity - 0.4).mean() diff --git a/ginka/model/model.py b/ginka/model/model.py index 515dac8..6ef6cbf 100644 --- a/ginka/model/model.py +++ b/ginka/model/model.py @@ -37,12 +37,12 @@ if __name__ == "__main__": print_memory("初始化后") # 前向传播 - output, output_softmax = model(feat) + output = model(feat) print_memory("前向传播后") print(f"输入形状: feat={feat.shape}") - print(f"输出形状: output={output.shape}, softmax={output_softmax.shape}") + print(f"输出形状: output={output.shape}") # print(f"Input parameters: {sum(p.numel() for p in model.input.parameters())}") # print(f"Feature Encoder parameters: {sum(p.numel() for p in model.feat_enc.parameters())}") print(f"UNet parameters: {sum(p.numel() for p in model.unet.parameters())}") diff --git a/ginka/model/unet.py b/ginka/model/unet.py index 2d52ff2..dd1d185 100644 --- a/ginka/model/unet.py +++ b/ginka/model/unet.py @@ -1,7 +1,9 @@ import torch import torch.nn as nn import torch.nn.functional as F -from shared.constant import FEAT_DIM +from torch_geometric.nn import GCNConv +from torch_geometric.utils import grid +from shared.attention import ChannelAttention class GinkaTransformerEncoder(nn.Module): def __init__(self, in_dim, hidden_dim, out_dim, token_size, ff_dim, num_heads=8, num_layers=6): @@ -33,7 +35,7 @@ class GinkaTransformerEncoder(nn.Module): return x class ConvBlock(nn.Module): - def __init__(self, in_ch, out_ch): + def __init__(self, in_ch, out_ch, atte=True): super().__init__() self.conv = nn.Sequential( nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='replicate'), @@ -41,11 +43,69 @@ class ConvBlock(nn.Module): nn.ELU(), nn.Conv2d(out_ch, out_ch, 3, padding=1, padding_mode='replicate'), nn.InstanceNorm2d(out_ch), - nn.ELU(), ) + if atte: + self.conv.append(ChannelAttention(out_ch)) + self.conv.append(nn.ELU()) def forward(self, x): return self.conv(x) + +class GCNBlock(nn.Module): + def __init__(self, in_ch, hidden_ch, out_ch, w, h): + super().__init__() + self.conv1 = GCNConv(in_ch, hidden_ch) + self.conv2 = GCNConv(hidden_ch, out_ch) + self.norm1 = nn.LayerNorm(hidden_ch) + self.norm2 = nn.LayerNorm(out_ch) + self.single_edge_index, _ = grid(h, w) # [2, E] for a single map + + def forward(self, x): + # x: [B, C, H, W] + B, C, H, W = x.shape + + # Reshape to [B * H * W, C] + x = x.permute(0, 2, 3, 1).reshape(B * H * W, C) + + # Construct batched edge index + device = x.device + edge_index = self._batch_edge_index(B, self.single_edge_index.to(device), H * W) + + # Batch vector for PyG (not strictly needed for GCNConv, but useful if you switch to GAT/Pooling) + # batch = torch.arange(B, device=device).repeat_interleave(H * W) + + # GCN forward + x = self.conv1(x, edge_index) + x = F.elu(self.norm1(x)) + x = self.conv2(x, edge_index) + x = F.elu(self.norm2(x)) + + # Reshape back to [B, C, H, W] + x = x.view(B, H, W, -1).permute(0, 3, 1, 2) + return x + + def _batch_edge_index(self, B, edge_index, num_nodes_per_batch): + # 批次偏移 edge_index + edge_index = edge_index.clone() # [2, E] + batch_edge_index = [] + for i in range(B): + offset = i * num_nodes_per_batch + batch_edge_index.append(edge_index + offset) + return torch.cat(batch_edge_index, dim=1) + +class FusionModule(nn.Module): + def __init__(self, in_ch, out_ch): + super().__init__() + self.conv = nn.Sequential( + nn.Conv2d(in_ch, out_ch, 1), + nn.InstanceNorm2d(out_ch), + nn.ELU() + ) + + def forward(self, x1, x2): + x = torch.cat([x1, x2], dim=1) + x = self.conv(x) + return x class GinkaEncoder(nn.Module): """编码器(下采样)部分""" @@ -59,6 +119,21 @@ class GinkaEncoder(nn.Module): x = self.pool(x) return x +class GinkaGCNFusedEncoder(nn.Module): + def __init__(self, in_ch, out_ch, w, h): + super().__init__() + self.conv = ConvBlock(in_ch, out_ch) + self.gcn = GCNBlock(out_ch, out_ch*2, out_ch, w, h) + self.pool = nn.MaxPool2d(2) + self.fusion = FusionModule(out_ch*2, out_ch) + + def forward(self, x): + x = self.conv(x) + x = self.pool(x) + x2 = self.gcn(x) + x = self.fusion(x, x2) + return x + class GinkaUpSample(nn.Module): def __init__(self, in_ch, out_ch): super().__init__() @@ -83,6 +158,44 @@ class GinkaDecoder(nn.Module): x = torch.cat([x, feat], dim=1) x = self.conv(x) return x + +class GinkaGCNFusedDecoder(nn.Module): + def __init__(self, in_ch, out_ch, w, h): + super().__init__() + self.upsample = GinkaUpSample(in_ch, in_ch // 2) + self.conv = ConvBlock(in_ch, out_ch) + self.gcn = GCNBlock(out_ch, out_ch*2, out_ch, w, h) + self.fusion = FusionModule(out_ch*2, out_ch) + + def forward(self, x, feat): + x = self.upsample(x) + x = torch.cat([x, feat], dim=1) + x = self.conv(x) + x2 = self.gcn(x) + x = self.fusion(x, x2) + return x + +class GinkaBottleneck(nn.Module): + def __init__(self, module_ch, w, h): + super().__init__() + self.transformer = GinkaTransformerEncoder( + in_dim=module_ch*w*h, hidden_dim=module_ch*w*h, out_dim=module_ch*w*h, + token_size=16, ff_dim=1024, num_layers=4 + ) + self.gcn = GCNBlock(module_ch, module_ch*2, module_ch, 4, 4) + self.fusion = FusionModule(module_ch*2, module_ch) + + def forward(self, x): + B = x.size(0) + + x1 = x.view(B, 512, 16).permute(0, 2, 1) # [B, 16, in_ch] + x1 = self.transformer(x1) + x1 = x1.permute(0, 2, 1).view(B, 512, 4, 4) # [B, out_ch, 4, 4] + x2 = self.gcn(x) + + x = self.fusion(x1, x2) + + return x class GinkaUNet(nn.Module): def __init__(self, base_ch=64, out_ch=32, feat_dim=1024): @@ -94,17 +207,14 @@ class GinkaUNet(nn.Module): token_size=4, ff_dim=feat_dim*2, num_layers=4 ) self.down1 = ConvBlock(2, base_ch) - self.down2 = GinkaEncoder(base_ch, base_ch*2) - self.down3 = GinkaEncoder(base_ch*2, base_ch*4) + self.down2 = GinkaGCNFusedEncoder(base_ch, base_ch*2, 16, 16) + self.down3 = GinkaGCNFusedEncoder(base_ch*2, base_ch*4, 8, 8) self.down4 = GinkaEncoder(base_ch*4, base_ch*8) - self.bottleneck = GinkaTransformerEncoder( - in_dim=base_ch*8*4*4, hidden_dim=base_ch*8*4*4, out_dim=base_ch*8*4*4, - token_size=16, ff_dim=1024, num_layers=4 - ) + self.bottleneck = GinkaBottleneck(base_ch*8, 4, 4) - self.up1 = GinkaDecoder(base_ch*8, base_ch*4) - self.up2 = GinkaDecoder(base_ch*4, base_ch*2) - self.up3 = GinkaDecoder(base_ch*2, base_ch) + self.up1 = GinkaGCNFusedDecoder(base_ch*8, base_ch*4, 8, 8) + self.up2 = GinkaGCNFusedDecoder(base_ch*4, base_ch*2, 16, 16) + self.up3 = GinkaGCNFusedDecoder(base_ch*2, base_ch, 32, 32) self.final = nn.Sequential( nn.Conv2d(base_ch, out_ch, 1), @@ -121,9 +231,7 @@ class GinkaUNet(nn.Module): x2 = self.down2(x1) # [B, 128, 16, 16] x3 = self.down3(x2) # [B, 256, 8, 8] x4 = self.down4(x3) # [B, 512, 4, 4] - x4 = x4.view(B, 512, 16).permute(0, 2, 1) # [B, 16, 512] - x4 = self.bottleneck(x4) # [B, 16, 512] - x4 = x4.permute(0, 2, 1).view(B, 512, 4, 4) # [B, 512, 4, 4] + x4 = self.bottleneck(x4) # [B, 512, 4, 4] # 上采样 x = self.up1(x4, x3) # [B, 256, 8, 8] diff --git a/ginka/train_wgan.py b/ginka/train_wgan.py index 019cddb..1b5cbec 100644 --- a/ginka/train_wgan.py +++ b/ginka/train_wgan.py @@ -11,6 +11,7 @@ from .model.model import GinkaModel from .dataset import GinkaWGANDataset from .model.loss import WGANGinkaLoss from minamo.model.model import MinamoScoreModule +from minamo.model.similarity import MinamoSimilarityModel from shared.graph import batch_convert_soft_map_to_graph from shared.image import matrix_to_image_cv from shared.constant import VISION_WEIGHT, TOPO_WEIGHT @@ -26,10 +27,12 @@ disable_tqdm = not sys.stdout.isatty() def parse_arguments(): parser = argparse.ArgumentParser(description="training codes") parser.add_argument("--resume", type=bool, default=False) - parser.add_argument("--state_ginka", type=str, default="result/ginka.pth") - parser.add_argument("--state_minamo", type=str, default="result/minamo.pth") + parser.add_argument("--state_ginka", type=str, default="result/wgan/ginka-100.pth") + parser.add_argument("--state_minamo", type=str, default="result/wgan/minamo-100.pth") parser.add_argument("--train", type=str, default="ginka-dataset.json") parser.add_argument("--epochs", type=int, default=100) + parser.add_argument("--checkpoint", type=int, default=5) + parser.add_argument("--load_optim", type=bool, default=True) args = parser.parse_args() return args @@ -49,14 +52,20 @@ def train(): ginka = GinkaModel() minamo = MinamoScoreModule() + minamo_sim = MinamoSimilarityModel() ginka.to(device) minamo.to(device) + minamo_sim.to(device) dataset = GinkaWGANDataset(args.train, device) dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True) optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9)) optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-5, betas=(0.0, 0.9)) + optimizer_minamo_sim = optim.Adam(minamo_sim.parameters(), lr=1e-4) + + # scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=args.epochs) + # scheduler_minamo = optim.lr_scheduler.CosineAnnealingLR(optimizer_minamo, T_max=args.epochs) criterion = WGANGinkaLoss() @@ -67,14 +76,29 @@ def train(): tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED) if args.resume: - data = torch.load(args.state_ginka, map_location=device) - ginka.load_state_dict(data["model_state"], strict=False) - data = torch.load(args.state_minamo, map_location=device) - minamo.load_state_dict(data["model_state"], strict=False) + data_ginka = torch.load(args.state_ginka, map_location=device) + data_minamo = torch.load(args.state_minamo, map_location=device) + + ginka.load_state_dict(data_ginka["model_state"], strict=False) + minamo.load_state_dict(data_minamo["model_state"], strict=False) + + if data_ginka["c_steps"] is not None and data_ginka["g_steps"] is not None: + c_steps = data_ginka["c_steps"] + g_steps = data_ginka["g_steps"] + + if args.load_optim: + if data_ginka["optim_state"] is not None: + optimizer_ginka.load_state_dict(data_ginka["optim_state"]) + if data_minamo["optim_state"] is not None: + optimizer_minamo.load_state_dict(data_minamo["optim_state"]) + if data_minamo["optim_state_sim"] is not None: + optimizer_minamo_sim.load_state_dict(data_minamo["optim_state_sim"]) + print("Train from loaded state.") - for epoch in tqdm(range(args.epochs), desc="GAN Training", disable=disable_tqdm): + for epoch in tqdm(range(args.epochs), desc="WGAN Training", disable=disable_tqdm): loss_total_minamo = torch.Tensor([0]).to(device) + loss_total_minamo_sim = torch.Tensor([0]).to(device) loss_total_ginka = torch.Tensor([0]).to(device) dis_total = torch.Tensor([0]).to(device) @@ -83,13 +107,11 @@ def train(): real_data = real_data.to(device) real_graph = batch_convert_soft_map_to_graph(real_data) - optimizer_ginka.zero_grad() - # ---------- 训练判别器 for _ in range(c_steps): # 生成假样本 optimizer_minamo.zero_grad() - z = torch.randn(batch_size, 1024, device=device) + z = torch.rand(batch_size, 1024, device=device) fake_data = ginka(z) fake_data = fake_data.detach() @@ -97,59 +119,65 @@ def train(): # 反向传播 dis, loss_d = criterion.discriminator_loss(minamo, real_data, real_graph, fake_data) loss_d.backward() - # torch.nn.utils.clip_grad_norm_(minamo.parameters(), max_norm=1.0) + # torch.nn.utils.clip_grad_norm_(minamo.parameters(), max_norm=2.0) # total_norm = torch.linalg.vector_norm(torch.stack([torch.linalg.vector_norm(p.grad) for p in minamo.topo_model.parameters()]), 2) # print("Critic 梯度范数:", total_norm.item()) # print("Critic 输入范围:", fake_data.min().item(), fake_data.max().item(), real_data.min().item(), real_data.max().item()) # print("Critic 输出范围:", d_real.min().item(), d_real.max().item()) optimizer_minamo.step() - loss_total_minamo += loss_d - dis_total += dis + loss_total_minamo += loss_d.detach() + dis_total += dis.detach() # ---------- 训练生成器 for _ in range(g_steps): + optimizer_ginka.zero_grad() + # optimizer_minamo_sim.zero_grad() + z1 = torch.randn(batch_size, 1024, device=device) z2 = torch.randn(batch_size, 1024, device=device) fake_softmax1, fake_softmax2 = ginka(z1), ginka(z2) - loss_g = criterion.generator_loss(minamo, fake_softmax1, fake_softmax2) + # 先训练辅助判别器 + # loss_c_assist = criterion.discriminator_loss_assist2(minamo_sim, real_data, fake_softmax1, fake_softmax2) + # loss_c_assist.backward(retain_graph=True) + # optimizer_minamo_sim.step() + + loss_g = criterion.generator_loss(minamo, minamo_sim, fake_softmax1, fake_softmax2) loss_g.backward() optimizer_ginka.step() loss_total_ginka += loss_g + # loss_total_minamo_sim += loss_c_assist.detach() # tqdm.write(f"{dis.item():.12f}, {loss_d.item():.12f}, {loss_g.item():.12f}") avg_loss_ginka = loss_total_ginka.item() / len(dataloader) / g_steps avg_loss_minamo = loss_total_minamo.item() / len(dataloader) / c_steps + avg_loss_minamo_sim = loss_total_minamo_sim.item() / len(dataloader) / g_steps avg_dis = dis_total.item() / len(dataloader) / c_steps tqdm.write( - f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | Wasserstein Loss: {avg_dis:.8f} | Loss Ginka: {avg_loss_ginka:.8f} | Loss Minamo: {avg_loss_minamo:.8f}" + f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +\ + f"Epoch: {epoch + 1} | W Loss: {avg_dis:.8f} | " +\ + f"G Loss: {avg_loss_ginka:.8f} | D Loss: {avg_loss_minamo:.8f} | " +\ + f"lr G: {(optimizer_ginka.param_groups[0]['lr']):.8f}" ) + # scheduler_ginka.step() + # scheduler_minamo.step() + if avg_dis < 0: g_steps = max(int(-avg_dis * 5), 1) else: g_steps = 1 - # if avg_dis > 0: - # c_steps = min(max(int(avg_dis * 5), 1), 5) - # else: - # c_steps = 1 - - # if avg_loss_minamo > 0: - # c_steps += min(max(int(avg_loss_minamo * 3), 1), 5) - # else: - # c_steps += 0 - - # if avg_dis > 3: - # c_steps = 3 - # else: - # c_steps = 1 + if avg_loss_ginka > 0 or avg_loss_minamo > 0: + c_steps = min(5 + (avg_loss_ginka + avg_loss_minamo) * 5, 15) + else: + c_steps = 5 # 每若干轮输出一次图片,并保存检查点 - if (epoch + 1) % 5 == 0: + if (epoch + 1) % args.checkpoint == 0: # 输出 20 张图片,每批次 4 张,一共五批 idx = 0 with torch.no_grad(): @@ -165,18 +193,25 @@ def train(): # 保存检查点 torch.save({ - "model_state": ginka.state_dict() + "model_state": ginka.state_dict(), + "optim_state": optimizer_ginka.state_dict(), + "c_steps": c_steps, + "g_steps": g_steps }, f"result/wgan/ginka-{epoch + 1}.pth") torch.save({ - "model_state": minamo.state_dict() + "model_state": minamo.state_dict(), + "model_state_sim": minamo_sim.state_dict(), + "optim_state": optimizer_minamo.state_dict(), + "optim_state_sim": optimizer_minamo_sim.state_dict() }, f"result/wgan/minamo-{epoch + 1}.pth") print("Train ended.") torch.save({ - "model_state": ginka.state_dict() + "model_state": ginka.state_dict(), }, f"result/ginka.pth") torch.save({ - "model_state": minamo.state_dict() + "model_state": minamo.state_dict(), + "model_state_sim": minamo_sim.state_dict(), }, f"result/minamo.pth") if __name__ == "__main__": diff --git a/minamo/model/model.py b/minamo/model/model.py index 0de0350..5d245e0 100644 --- a/minamo/model/model.py +++ b/minamo/model/model.py @@ -20,40 +20,6 @@ class MinamoModel(nn.Module): return vision_feat, topo_feat -class MinamoVisionScore(nn.Module): - def __init__(self, tile_types=32): - super().__init__() - # 视觉相似度部分 - self.vision_model = MinamoVisionModel(tile_types) - # 输出层 - self.vision_fc = nn.Sequential( - nn.Linear(512, 2048), - nn.LeakyReLU(0.2), - nn.Linear(2048, 1) - ) - - def forward(self, map): - vision_feat = self.vision_model(map) - vis_score = self.vision_fc(vision_feat) - return vis_score - -class MinamoTopoScore(nn.Module): - def __init__(self, tile_types=32): - super().__init__() - # 拓扑相似度部分 - self.topo_model = MinamoTopoModel(tile_types) - # 输出层 - self.topo_fc = nn.Sequential( - nn.Linear(512, 2048), - nn.LeakyReLU(0.2), - nn.Linear(2048, 1) - ) - - def forward(self, graph): - topo_feat = self.topo_model(graph) - topo_score = self.topo_fc(topo_feat) - return topo_score - class MinamoScoreModule(nn.Module): def __init__(self, tile_types=32): super().__init__() diff --git a/minamo/model/similarity.py b/minamo/model/similarity.py new file mode 100644 index 0000000..41d3ba3 --- /dev/null +++ b/minamo/model/similarity.py @@ -0,0 +1,83 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_geometric.nn import GCNConv, global_mean_pool +from torch_geometric.data import Data + +class MinamoSimilarityVision(nn.Module): + def __init__(self, in_ch, out_ch): + super().__init__() + self.conv = nn.Sequential( + nn.Conv2d(in_ch, in_ch * 2, 3, padding=1), + nn.InstanceNorm2d(in_ch * 2), + nn.ReLU(), + + nn.Conv2d(in_ch * 2, in_ch * 4, 3, padding=1), + nn.InstanceNorm2d(in_ch * 4), + nn.ReLU(), + + nn.Conv2d(in_ch * 4, in_ch * 8, 3), + nn.InstanceNorm2d(in_ch * 8), + nn.ReLU(), + + nn.AdaptiveAvgPool2d(1) + ) + self.fc = nn.Sequential( + nn.Linear(in_ch * 8, out_ch), + ) + + def forward(self, x): + x = self.conv(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + return x + +class MinamoSimilarityTopo(nn.Module): + def __init__(self, in_ch, hidden_dim, out_ch): + super().__init__() + self.input_fc = nn.Sequential( + nn.Linear(in_ch, hidden_dim), + nn.LayerNorm(hidden_dim), + nn.ReLU(), + ) + + self.conv1 = GCNConv(hidden_dim, hidden_dim*2) + self.conv2 = GCNConv(hidden_dim*2, hidden_dim*4) + self.conv3 = GCNConv(hidden_dim*4, hidden_dim*8) + + self.norm1 = nn.LayerNorm(hidden_dim*2) + self.norm2 = nn.LayerNorm(hidden_dim*4) + self.norm3 = nn.LayerNorm(hidden_dim*8) + + self.output_fc = nn.Sequential( + nn.Linear(hidden_dim*8, out_ch) + ) + + def forward(self, graph: Data): + x = self.input_fc(graph.x) + + x = self.conv1(x, graph.edge_index) + x = F.relu(self.norm1(x)) + + x = self.conv2(x, graph.edge_index) + x = F.relu(self.norm2(x)) + + x = self.conv3(x, graph.edge_index) + x = F.relu(self.norm3(x)) + + x = global_mean_pool(x, graph.batch) + x = self.output_fc(x) + + return x + +class MinamoSimilarityModel(nn.Module): + def __init__(self, tile_type=32): + super().__init__() + self.vision = MinamoSimilarityVision(tile_type, 512) + self.topo = MinamoSimilarityTopo(tile_type, 64, 512) + + def forward(self, x, graph): + vis_feat = self.vision(x) + topo_feat = self.topo(graph) + return vis_feat, topo_feat + \ No newline at end of file diff --git a/minamo/model/vision.py b/minamo/model/vision.py index 50c2c45..e52a03e 100644 --- a/minamo/model/vision.py +++ b/minamo/model/vision.py @@ -7,16 +7,16 @@ class MinamoVisionModel(nn.Module): def __init__(self, in_ch=32, out_dim=512): super().__init__() self.conv = nn.Sequential( - spectral_norm(nn.Conv2d(in_ch, in_ch*2, 3, stride=2)), # 6*6 + spectral_norm(nn.Conv2d(in_ch, in_ch*2, 3)), # 11*11 nn.LeakyReLU(0.2), - spectral_norm(nn.Conv2d(in_ch*2, in_ch*4, 3)), #4*4 + spectral_norm(nn.Conv2d(in_ch*2, in_ch*4, 3)), #9*9 nn.LeakyReLU(0.2), - spectral_norm(nn.Conv2d(in_ch*4, in_ch*8, 3)), # 2*2 + spectral_norm(nn.Conv2d(in_ch*4, in_ch*8, 3)), # 7*7 nn.LeakyReLU(0.2), - nn.Flatten() + nn.AdaptiveAvgPool2d(2) ) self.fc = nn.Sequential( spectral_norm(nn.Linear(in_ch*8*2*2, out_dim)), @@ -25,5 +25,6 @@ class MinamoVisionModel(nn.Module): def forward(self, x): x = self.conv(x) + x = x.view(x.size(0), -1) x = self.fc(x) return x diff --git a/shared/attention.py b/shared/attention.py index c77cc69..c4e7c14 100644 --- a/shared/attention.py +++ b/shared/attention.py @@ -9,7 +9,7 @@ class ChannelAttention(nn.Module): self.channel_att = nn.Sequential( nn.AdaptiveAvgPool2d(1), nn.Conv2d(channels, channels//reduction, 1), - nn.GELU(), + nn.ELU(), nn.Conv2d(channels//reduction, channels, 1), nn.Sigmoid() ) diff --git a/shared/constant.py b/shared/constant.py index ed07a65..d194a6c 100644 --- a/shared/constant.py +++ b/shared/constant.py @@ -2,5 +2,5 @@ VIS_DIM = 512 TOPO_DIM = 512 FEAT_DIM = 1024 -VISION_WEIGHT = 0.3 -TOPO_WEIGHT = 0.7 +VISION_WEIGHT = 0.2 +TOPO_WEIGHT = 0.8 diff --git a/shared/similarity/topo.py b/shared/similarity/topo.py new file mode 100644 index 0000000..3fe80ce --- /dev/null +++ b/shared/similarity/topo.py @@ -0,0 +1,273 @@ +# Converted Python version of the JS code +import math +from typing import Dict, Set, List, Tuple, Union +from collections import deque, defaultdict + +# 拓扑相似度,由 ChatGPT-4o 从 ts 转译而来 + +class ResourceArea: + def __init__(self): + self.type = 'resource' + self.resources: Dict[int, int] = {} + self.members: Set[int] = set() + self.neighbor: Set[int] = set() + +class BranchNode: + def __init__(self, tile: int): + self.type = 'branch' + self.tile = tile + self.neighbor: Set[int] = set() + +class ResourceNode: + def __init__(self, resource_type: int, area: ResourceArea): + self.type = 'resource' + self.resourceType = resource_type + self.neighbor = area.neighbor + self.resourceArea = area + +GinkaNode = Union[BranchNode, ResourceNode] + +class GinkaGraph: + def __init__(self): + self.graph: Dict[int, GinkaNode] = {} + self.resourceMap: Dict[int, int] = {} + self.areaMap: List[ResourceArea] = [] + self.visitedEntrance: Set[int] = set() + self.visited: Set[int] = set() + +class GinkaTopologicalGraphs: + def __init__(self): + self.graphs: List[GinkaGraph] = [] + self.entranceMap: Dict[int, GinkaGraph] = {} + self.unreachable: Set[int] = set() + +TILE_TYPE = set(range(13)) +BRANCH_TYPE = {6, 7, 8, 9} +ENTRANCE_TYPE = {10, 11} +RESOURCE_TYPE = {0, 2, 3, 4, 5, 10, 11, 12} + +directions: List[Tuple[int, int]] = [ + (-1, 0), (1, 0), (0, -1), (0, 1) +] + +def find_resource_nodes(map_: List[List[int]]): + width, height = len(map_[0]), len(map_) + visited = set() + areas = [] + resource_map = {} + + for ny in range(height): + for nx in range(width): + tile = map_[ny][nx] + index = ny * width + nx + if index in visited or tile not in RESOURCE_TYPE: + continue + queue = deque([(nx, ny)]) + area = ResourceArea() + area.resources[tile] = 1 + area.members.add(index) + while queue: + cx, cy = queue.popleft() + cindex = cy * width + cx + if cindex in visited: + continue + ctile = map_[cy][cx] + if ctile not in RESOURCE_TYPE: + continue + visited.add(cindex) + area.resources[ctile] = area.resources.get(ctile, 0) + 1 + area.members.add(cindex) + resource_map[cindex] = len(areas) + for dx, dy in directions: + px, py = cx + dx, cy + dy + if 0 <= px < width and 0 <= py < height: + queue.append((px, py)) + areas.append(area) + return areas, resource_map + +def build_graph_from_entrance(map_: List[List[int]], entrance: int, resource_map: Dict[int, int], area_map: List[ResourceArea]) -> GinkaGraph: + width, height = len(map_[0]), len(map_) + graph = GinkaGraph() + graph.resourceMap = resource_map + graph.areaMap = area_map + + visited = graph.visited + visited_entrance = graph.visitedEntrance + visited_entrance.add(entrance) + + branch_nodes = set() + queue = deque([(entrance % width, entrance // width)]) + + while queue: + x, y = queue.popleft() + index = y * width + x + if index in visited: + continue + tile = map_[y][x] + if tile in ENTRANCE_TYPE: + visited_entrance.add(index) + if tile in BRANCH_TYPE: + branch_nodes.add(index) + visited.add(index) + for dx, dy in directions: + px, py = x + dx, y + dy + if 0 <= px < width and 0 <= py < height and map_[py][px] != 1: + queue.append((px, py)) + + for v in branch_nodes: + x, y = v % width, v // width + if v not in graph.graph: + graph.graph[v] = BranchNode(map_[y][x]) + node = graph.graph[v] + for dx, dy in directions: + px, py = x + dx, y + dy + if 0 <= px < width and 0 <= py < height: + index = py * width + px + if index in branch_nodes: + node.neighbor.add(index) + elif index in resource_map: + area = area_map[resource_map[index]] + area.neighbor.add(v) + for m in area.members: + node.neighbor.add(m) + + for area in area_map: + for index in area.members: + x, y = index % width, index // width + tile = map_[y][x] + if tile == 0: + continue + node = ResourceNode(tile, area) + graph.graph[index] = node + + return graph + +def build_topological_graph(map_: List[List[int]]) -> GinkaTopologicalGraphs: + width, height = len(map_[0]), len(map_) + entrances = set() + entrances = {y * width + x for y in range(height) for x in range(width) if map_[y][x] in ENTRANCE_TYPE} + area_map, resource_map = find_resource_nodes(map_) + + top_graph = GinkaTopologicalGraphs() + used_entrance = set() + total_visited = set() + + for entrance in entrances: + if entrance in used_entrance: + continue + graph = build_graph_from_entrance(map_, entrance, resource_map, area_map) + top_graph.graphs.append(graph) + for ent in graph.visitedEntrance: + used_entrance.add(ent) + top_graph.entranceMap[ent] = graph + total_visited.update(graph.visited) + + for y in range(height): + for x in range(width): + index = y * width + x + if index not in total_visited and map_[y][x] != 1: + top_graph.unreachable.add(index) + + return top_graph + +class WLNode: + def __init__(self, pos: int, label: str): + self.originalPos = pos + self.originalLabel = label + self.currentLabel = label + self.neighbors: List['WLNode'] = [] + +def encode_node_labels(graph: GinkaGraph) -> List[WLNode]: + node_map = {} + nodes = [] + for pos, node in graph.graph.items(): + label = f"B:{node.tile}" if node.type == 'branch' else f"R:{node.resourceType}" + wl_node = WLNode(pos, label) + node_map[pos] = wl_node + nodes.append(wl_node) + + for node in nodes: + g_node = graph.graph[node.originalPos] + for neighbor in g_node.neighbor: + if neighbor in node_map: + node.neighbors.append(node_map[neighbor]) + + return nodes + +def weisfeiler_lehman_iteration(nodes: List[WLNode], iterations: int, decay: float = 0.6) -> Dict[str, float]: + label_history = [] + for _ in range(iterations): + new_labels = [] + for node in nodes: + neighbor_labels = sorted(n.currentLabel for n in node.neighbors) + composite = f"{node.currentLabel}|{','.join(neighbor_labels)}"[:8192] + new_labels.append(composite) + for node, new_label in zip(nodes, new_labels): + node.currentLabel = new_label + label_history.append(new_labels[:]) + + weight = 1.0 + label_counts = defaultdict(float) + for layer in label_history: + for label in layer: + label_counts[label] += weight + weight *= decay + for node in nodes: + label_counts[node.originalLabel] += weight + return dict(label_counts) + +def vectorize_features(features: Dict[str, float], vocab: List[str]) -> List[float]: + vec = [0.0] * len(vocab) + for label, count in features.items(): + if label in vocab: + idx = vocab.index(label) + vec[idx] += count + return vec + +def cosine_similarity(a: List[float], b: List[float]) -> float: + dot = sum(x * y for x, y in zip(a, b)) + norm_a = math.sqrt(sum(x * x for x in a)) + norm_b = math.sqrt(sum(y * y for y in b)) + if norm_a == 0 or norm_b == 0: + return 0.0 + return dot / (norm_a * norm_b) + +def wl_kernel(graph_a: GinkaGraph, graph_b: GinkaGraph, iterations: int = 3) -> float: + nodes_a = encode_node_labels(graph_a) + nodes_b = encode_node_labels(graph_b) + features_a = weisfeiler_lehman_iteration(nodes_a, iterations) + features_b = weisfeiler_lehman_iteration(nodes_b, iterations) + vocab = list(set(features_a.keys()) | set(features_b.keys())) + vec_a = vectorize_features(features_a, vocab) + vec_b = vectorize_features(features_b, vocab) + return cosine_similarity(vec_a, vec_b) + +def overall_similarity(a: GinkaTopologicalGraphs, b: GinkaTopologicalGraphs) -> float: + graphs_a = a.graphs + graphs_b = b.graphs + + total_similarity = 0.0 + compared_graphs: Set[GinkaGraph] = set() + + for ga in graphs_a: + max_similarity = 0.0 + max_graph = None + for gb in graphs_b: + if gb in compared_graphs: + continue + min_nodes = min(len(ga.graph), len(gb.graph)) + iterations = max(1, math.ceil(math.log(min_nodes))) + similarity = wl_kernel(ga, gb, iterations) + if similarity > max_similarity and not math.isnan(similarity): + max_similarity = similarity + max_graph = gb + if similarity == 1: + break + total_similarity += max_similarity + if max_graph: + compared_graphs.add(max_graph) + + reduction = 1 / (1 + abs(len(a.unreachable) - len(b.unreachable))) + if not graphs_a: + return 0.0 + return math.sqrt(total_similarity / len(graphs_a)) * reduction diff --git a/shared/similarity/vision.py b/shared/similarity/vision.py new file mode 100644 index 0000000..09347d0 --- /dev/null +++ b/shared/similarity/vision.py @@ -0,0 +1,75 @@ +from typing import List, Dict +import math +import numpy as np + +# 视觉相似度,由 ChatGPT-4o 从 ts 转译而来 + +class VisualSimilarityConfig: + def __init__(self): + self.type_weights: Dict[int, float] = { + 0: 0.2, 1: 0.3, 2: 0.6, 3: 0.7, 4: 0.7, 5: 0.5, + 6: 0.4, 7: 0.5, 8: 0.6, 9: 0.6, 10: 0.4, 11: 0.4, 12: 0.7 + } + self.enable_visual_focus: bool = True + self.enable_density_awareness: bool = True + +def generate_focus_weights(rows: int, cols: int) -> List[List[float]]: + weights = [] + center_x = cols / 2 + center_y = rows / 2 + for i in range(rows): + row_weights = [] + for j in range(cols): + dx = (j - center_x) / cols + dy = (i - center_y) / rows + distance = math.sqrt(dx ** 2 + dy ** 2) + gaussian = math.exp(-(distance ** 2) / (2 * 0.3 ** 2)) + row_weights.append(1.0 + 0.6 * gaussian) + weights.append(row_weights) + return weights + +def calculate_density_impact(map1: List[List[int]], map2: List[List[int]], type_weights: Dict[int, float]) -> List[List[float]]: + rows, cols = len(map1), len(map1[0]) + density_map = [[0.0 for _ in range(cols)] for _ in range(rows)] + window_size = 3 + half_window = window_size // 2 + + for i in range(rows): + for j in range(cols): + density = 0 + for di in range(-half_window, half_window + 1): + for dj in range(-half_window, half_window + 1): + ni, nj = i + di, j + dj + if 0 <= ni < rows and 0 <= nj < cols: + weight1 = type_weights.get(map1[ni][nj], 0.5) + weight2 = type_weights.get(map2[ni][nj], 0.5) + density += (weight1 + weight2) / 2 + density_map[i][j] = 1.0 + 0.4 * (density / (window_size ** 2)) + return density_map + +def calculate_visual_similarity(map1: List[List[int]], map2: List[List[int]], config: VisualSimilarityConfig = None) -> float: + if config is None: + config = VisualSimilarityConfig() + + if len(map1) != len(map2) or len(map1[0]) != len(map2[0]): + return 0.0 + + rows, cols = len(map1), len(map1[0]) + total_score = 0.0 + max_possible_score = 0.0 + + focus_weights = generate_focus_weights(rows, cols) if config.enable_visual_focus else [[1.0 for _ in range(cols)] for _ in range(rows)] + density_map = calculate_density_impact(map1, map2, config.type_weights) if config.enable_density_awareness else [[1.0 for _ in range(cols)] for _ in range(rows)] + + for i in range(rows): + for j in range(cols): + type1 = map1[i][j] + type2 = map2[i][j] + base_weight = max(config.type_weights.get(type1, 0.5), config.type_weights.get(type2, 0.5)) + spatial_weight = focus_weights[i][j] * density_map[i][j] + type_score = 1.0 if type1 == type2 else 0.0 + + total_score += type_score * base_weight * spatial_weight + max_possible_score += base_weight * spatial_weight + + return total_score / max_possible_score if max_possible_score > 0 else 0.0 diff --git a/train.txt b/train.txt new file mode 100644 index 0000000..fba325c --- /dev/null +++ b/train.txt @@ -0,0 +1 @@ +python3 -u -m ginka.train_wgan --epochs 200 --checkpoint 20 --resume true --state_ginka result/wgan/ginka-400.pth --state_minamo result/wgan/minamo-400.pth >> output.log \ No newline at end of file