mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
perf: 改进 WGAN 训练
This commit is contained in:
parent
268b21e0b7
commit
99f46150be
@ -7,6 +7,8 @@ from torch_geometric.data import Data
|
||||
from minamo.model.model import MinamoModel
|
||||
from shared.graph import batch_convert_soft_map_to_graph
|
||||
from shared.constant import VISION_WEIGHT, TOPO_WEIGHT
|
||||
from shared.similarity.topo import overall_similarity, build_topological_graph
|
||||
from shared.similarity.vision import calculate_visual_similarity
|
||||
|
||||
CLASS_NUM = 32
|
||||
ILLEGAL_MAX_NUM = 12
|
||||
@ -286,32 +288,21 @@ def interpolate_graph_features(real_graph, fake_graph, epsilon=0.5):
|
||||
|
||||
return Data(x=x_interp, edge_index=edge_index_interp, edge_attr=edge_attr_interp)
|
||||
|
||||
def js_divergence(P, Q, epsilon=1e-10):
|
||||
"""
|
||||
输入:
|
||||
P, Q: [B, C, H, W], 已通过 Softmax 处理
|
||||
输出:
|
||||
JS 散度标量(全局平均)
|
||||
"""
|
||||
# 转换为 [B, H, W, C] 以便在最后一维计算概率分布
|
||||
P = P.permute(0, 2, 3, 1) # [B, H, W, C]
|
||||
Q = Q.permute(0, 2, 3, 1)
|
||||
def js_divergence(p, q, eps=1e-8):
|
||||
# softmax 后变成概率分布
|
||||
m = 0.5 * (p + q)
|
||||
|
||||
# 平均分布 M = (P + Q)/2
|
||||
M = 0.5 * (P + Q)
|
||||
|
||||
# 计算 KL(P||M) 和 KL(Q||M)
|
||||
kl_pm = F.kl_div(torch.log(M + epsilon), P, reduction='none', log_target=False).sum(dim=-1) # [B, H, W]
|
||||
kl_qm = F.kl_div(torch.log(M + epsilon), Q, reduction='none', log_target=False).sum(dim=-1) # [B, H, W]
|
||||
|
||||
# JS 散度 = 0.5*(KL(P||M) + KL(Q||M))
|
||||
js = 0.5 * (kl_pm + kl_qm)
|
||||
|
||||
# 全局平均(可替换为其他聚合方式)
|
||||
return js.mean() # 标量
|
||||
# log_softmax 以供 kl_div 使用
|
||||
log_p = torch.log(p + eps)
|
||||
log_q = torch.log(q + eps)
|
||||
|
||||
kl_pm = F.kl_div(log_p, m, reduction='batchmean', log_target=False) # KL(p || m)
|
||||
kl_qm = F.kl_div(log_q, m, reduction='batchmean', log_target=False) # KL(q || m)
|
||||
|
||||
return torch.clamp(0.5 * (kl_pm + kl_qm), max=1.0)
|
||||
|
||||
class WGANGinkaLoss:
|
||||
def __init__(self, lambda_gp=50, weight=[0.7, 0.2, 0.1], diversity_lamda=0.2):
|
||||
def __init__(self, lambda_gp=100, weight=[0.8, 0.1, 0.1], diversity_lamda=0.4):
|
||||
self.lambda_gp = lambda_gp # 梯度惩罚系数
|
||||
self.weight = weight
|
||||
self.diversity_lamda = diversity_lamda
|
||||
@ -369,8 +360,50 @@ class WGANGinkaLoss:
|
||||
|
||||
return d_loss, d_loss + self.lambda_gp * grad_loss
|
||||
|
||||
def generator_loss_one(self, critic, fake):
|
||||
fake_graph = batch_convert_soft_map_to_graph(fake)
|
||||
def calculate_similarity_one(self, map1, map2):
|
||||
topo1 = build_topological_graph(map1)
|
||||
topo2 = build_topological_graph(map2)
|
||||
|
||||
vis_sim = calculate_visual_similarity(map1, map2)
|
||||
topo_sim = overall_similarity(topo1, topo2)
|
||||
|
||||
return vis_sim, topo_sim
|
||||
|
||||
def discriminator_loss_assist(self, critic, fake_data1, fake_data2):
|
||||
graph1 = batch_convert_soft_map_to_graph(fake_data1)
|
||||
graph2 = batch_convert_soft_map_to_graph(fake_data2)
|
||||
vis_feat_1, topo_feat_1 = critic(fake_data1, graph1)
|
||||
vis_feat_2, topo_feat_2 = critic(fake_data2, graph2)
|
||||
|
||||
batch1 = torch.argmax(fake_data1, dim=1).cpu().tolist()
|
||||
batch2 = torch.argmax(fake_data2, dim=1).cpu().tolist()
|
||||
|
||||
vis_sim_real = []
|
||||
topo_sim_real = []
|
||||
|
||||
for i in range(len(batch1)):
|
||||
vis_sim, topo_sim = self.calculate_similarity_one(batch1[i], batch2[i])
|
||||
vis_sim_real.append(vis_sim)
|
||||
topo_sim_real.append(topo_sim)
|
||||
|
||||
vis_sim_real = torch.Tensor(vis_sim_real)
|
||||
topo_sim_real = torch.Tensor(topo_sim_real)
|
||||
|
||||
pred_vis_sim = F.cosine_similarity(vis_feat_1, vis_feat_2).cpu()
|
||||
pred_topo_sim = F.cosine_similarity(topo_feat_1, topo_feat_2).cpu()
|
||||
|
||||
loss1 = F.l1_loss(pred_vis_sim, vis_sim_real) * VISION_WEIGHT + F.l1_loss(pred_topo_sim, topo_sim_real) * TOPO_WEIGHT
|
||||
|
||||
return loss1
|
||||
|
||||
def discriminator_loss_assist2(self, critic, real_data, fake_data1, fake_data2):
|
||||
loss1 = self.discriminator_loss_assist(critic, real_data, fake_data1)
|
||||
loss2 = self.discriminator_loss_assist(critic, real_data, fake_data2)
|
||||
loss3 = self.discriminator_loss_assist(critic, fake_data1, fake_data2)
|
||||
|
||||
return loss1 / 3.0 + loss2 / 3.0 + loss3 / 3.0
|
||||
|
||||
def generator_loss_one(self, critic, fake, fake_graph):
|
||||
fake_scores, _, _ = critic(fake, fake_graph)
|
||||
minamo_loss = -torch.mean(fake_scores)
|
||||
class_loss = outer_border_constraint_loss(fake) + inner_constraint_loss(fake)
|
||||
@ -383,16 +416,25 @@ class WGANGinkaLoss:
|
||||
]
|
||||
|
||||
return sum(losses)
|
||||
|
||||
def diversity_loss(self, fake1, fake2):
|
||||
fake1 = fake1[:, :, 1:-1, 1:-1]
|
||||
fake2 = fake2[:, :, 1:-1, 1:-1]
|
||||
|
||||
return js_divergence(fake1, fake2)
|
||||
|
||||
def generator_loss(self, critic, fake1, fake2):
|
||||
def generator_loss(self, critic, critic_assist, fake1, fake2):
|
||||
""" 生成器损失函数 """
|
||||
loss1 = self.generator_loss_one(critic, fake1)
|
||||
loss2 = self.generator_loss_one(critic, fake2)
|
||||
fake_graph1 = batch_convert_soft_map_to_graph(fake1)
|
||||
fake_graph2 = batch_convert_soft_map_to_graph(fake2)
|
||||
|
||||
return loss1 * 0.5 + loss2 * 0.5 - self.diversity_lamda * self.diversity_loss(fake1, fake2)
|
||||
loss1 = self.generator_loss_one(critic, fake1, fake_graph1)
|
||||
loss2 = self.generator_loss_one(critic, fake2, fake_graph2)
|
||||
|
||||
# vis_feat1, topo_feat1 = critic_assist(fake1, fake_graph1)
|
||||
# vis_feat2, topo_feat2 = critic_assist(fake2, fake_graph2)
|
||||
|
||||
# vis_sim = F.cosine_similarity(vis_feat1, vis_feat2)
|
||||
# topo_sim = F.cosine_similarity(topo_feat1, topo_feat2)
|
||||
# similarity = vis_sim * VISION_WEIGHT + topo_sim * TOPO_WEIGHT
|
||||
|
||||
# print(similarity.mean().item())
|
||||
# div_loss = F.l1_loss(fake1[:, :, 1:-1, 1:-1], fake2[:, :, 1:-1, 1:-1])
|
||||
|
||||
return loss1 * 0.5 + loss2 * 0.5\
|
||||
# + self.diversity_lamda * F.relu(0.7 - div_loss).mean()
|
||||
# + self.diversity_lamda * F.relu(similarity - 0.4).mean()
|
||||
|
||||
@ -37,12 +37,12 @@ if __name__ == "__main__":
|
||||
print_memory("初始化后")
|
||||
|
||||
# 前向传播
|
||||
output, output_softmax = model(feat)
|
||||
output = model(feat)
|
||||
|
||||
print_memory("前向传播后")
|
||||
|
||||
print(f"输入形状: feat={feat.shape}")
|
||||
print(f"输出形状: output={output.shape}, softmax={output_softmax.shape}")
|
||||
print(f"输出形状: output={output.shape}")
|
||||
# print(f"Input parameters: {sum(p.numel() for p in model.input.parameters())}")
|
||||
# print(f"Feature Encoder parameters: {sum(p.numel() for p in model.feat_enc.parameters())}")
|
||||
print(f"UNet parameters: {sum(p.numel() for p in model.unet.parameters())}")
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from shared.constant import FEAT_DIM
|
||||
from torch_geometric.nn import GCNConv
|
||||
from torch_geometric.utils import grid
|
||||
from shared.attention import ChannelAttention
|
||||
|
||||
class GinkaTransformerEncoder(nn.Module):
|
||||
def __init__(self, in_dim, hidden_dim, out_dim, token_size, ff_dim, num_heads=8, num_layers=6):
|
||||
@ -33,7 +35,7 @@ class GinkaTransformerEncoder(nn.Module):
|
||||
return x
|
||||
|
||||
class ConvBlock(nn.Module):
|
||||
def __init__(self, in_ch, out_ch):
|
||||
def __init__(self, in_ch, out_ch, atte=True):
|
||||
super().__init__()
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='replicate'),
|
||||
@ -41,11 +43,69 @@ class ConvBlock(nn.Module):
|
||||
nn.ELU(),
|
||||
nn.Conv2d(out_ch, out_ch, 3, padding=1, padding_mode='replicate'),
|
||||
nn.InstanceNorm2d(out_ch),
|
||||
nn.ELU(),
|
||||
)
|
||||
if atte:
|
||||
self.conv.append(ChannelAttention(out_ch))
|
||||
self.conv.append(nn.ELU())
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
class GCNBlock(nn.Module):
|
||||
def __init__(self, in_ch, hidden_ch, out_ch, w, h):
|
||||
super().__init__()
|
||||
self.conv1 = GCNConv(in_ch, hidden_ch)
|
||||
self.conv2 = GCNConv(hidden_ch, out_ch)
|
||||
self.norm1 = nn.LayerNorm(hidden_ch)
|
||||
self.norm2 = nn.LayerNorm(out_ch)
|
||||
self.single_edge_index, _ = grid(h, w) # [2, E] for a single map
|
||||
|
||||
def forward(self, x):
|
||||
# x: [B, C, H, W]
|
||||
B, C, H, W = x.shape
|
||||
|
||||
# Reshape to [B * H * W, C]
|
||||
x = x.permute(0, 2, 3, 1).reshape(B * H * W, C)
|
||||
|
||||
# Construct batched edge index
|
||||
device = x.device
|
||||
edge_index = self._batch_edge_index(B, self.single_edge_index.to(device), H * W)
|
||||
|
||||
# Batch vector for PyG (not strictly needed for GCNConv, but useful if you switch to GAT/Pooling)
|
||||
# batch = torch.arange(B, device=device).repeat_interleave(H * W)
|
||||
|
||||
# GCN forward
|
||||
x = self.conv1(x, edge_index)
|
||||
x = F.elu(self.norm1(x))
|
||||
x = self.conv2(x, edge_index)
|
||||
x = F.elu(self.norm2(x))
|
||||
|
||||
# Reshape back to [B, C, H, W]
|
||||
x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
|
||||
return x
|
||||
|
||||
def _batch_edge_index(self, B, edge_index, num_nodes_per_batch):
|
||||
# 批次偏移 edge_index
|
||||
edge_index = edge_index.clone() # [2, E]
|
||||
batch_edge_index = []
|
||||
for i in range(B):
|
||||
offset = i * num_nodes_per_batch
|
||||
batch_edge_index.append(edge_index + offset)
|
||||
return torch.cat(batch_edge_index, dim=1)
|
||||
|
||||
class FusionModule(nn.Module):
|
||||
def __init__(self, in_ch, out_ch):
|
||||
super().__init__()
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(in_ch, out_ch, 1),
|
||||
nn.InstanceNorm2d(out_ch),
|
||||
nn.ELU()
|
||||
)
|
||||
|
||||
def forward(self, x1, x2):
|
||||
x = torch.cat([x1, x2], dim=1)
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
class GinkaEncoder(nn.Module):
|
||||
"""编码器(下采样)部分"""
|
||||
@ -59,6 +119,21 @@ class GinkaEncoder(nn.Module):
|
||||
x = self.pool(x)
|
||||
return x
|
||||
|
||||
class GinkaGCNFusedEncoder(nn.Module):
|
||||
def __init__(self, in_ch, out_ch, w, h):
|
||||
super().__init__()
|
||||
self.conv = ConvBlock(in_ch, out_ch)
|
||||
self.gcn = GCNBlock(out_ch, out_ch*2, out_ch, w, h)
|
||||
self.pool = nn.MaxPool2d(2)
|
||||
self.fusion = FusionModule(out_ch*2, out_ch)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.pool(x)
|
||||
x2 = self.gcn(x)
|
||||
x = self.fusion(x, x2)
|
||||
return x
|
||||
|
||||
class GinkaUpSample(nn.Module):
|
||||
def __init__(self, in_ch, out_ch):
|
||||
super().__init__()
|
||||
@ -83,6 +158,44 @@ class GinkaDecoder(nn.Module):
|
||||
x = torch.cat([x, feat], dim=1)
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
class GinkaGCNFusedDecoder(nn.Module):
|
||||
def __init__(self, in_ch, out_ch, w, h):
|
||||
super().__init__()
|
||||
self.upsample = GinkaUpSample(in_ch, in_ch // 2)
|
||||
self.conv = ConvBlock(in_ch, out_ch)
|
||||
self.gcn = GCNBlock(out_ch, out_ch*2, out_ch, w, h)
|
||||
self.fusion = FusionModule(out_ch*2, out_ch)
|
||||
|
||||
def forward(self, x, feat):
|
||||
x = self.upsample(x)
|
||||
x = torch.cat([x, feat], dim=1)
|
||||
x = self.conv(x)
|
||||
x2 = self.gcn(x)
|
||||
x = self.fusion(x, x2)
|
||||
return x
|
||||
|
||||
class GinkaBottleneck(nn.Module):
|
||||
def __init__(self, module_ch, w, h):
|
||||
super().__init__()
|
||||
self.transformer = GinkaTransformerEncoder(
|
||||
in_dim=module_ch*w*h, hidden_dim=module_ch*w*h, out_dim=module_ch*w*h,
|
||||
token_size=16, ff_dim=1024, num_layers=4
|
||||
)
|
||||
self.gcn = GCNBlock(module_ch, module_ch*2, module_ch, 4, 4)
|
||||
self.fusion = FusionModule(module_ch*2, module_ch)
|
||||
|
||||
def forward(self, x):
|
||||
B = x.size(0)
|
||||
|
||||
x1 = x.view(B, 512, 16).permute(0, 2, 1) # [B, 16, in_ch]
|
||||
x1 = self.transformer(x1)
|
||||
x1 = x1.permute(0, 2, 1).view(B, 512, 4, 4) # [B, out_ch, 4, 4]
|
||||
x2 = self.gcn(x)
|
||||
|
||||
x = self.fusion(x1, x2)
|
||||
|
||||
return x
|
||||
|
||||
class GinkaUNet(nn.Module):
|
||||
def __init__(self, base_ch=64, out_ch=32, feat_dim=1024):
|
||||
@ -94,17 +207,14 @@ class GinkaUNet(nn.Module):
|
||||
token_size=4, ff_dim=feat_dim*2, num_layers=4
|
||||
)
|
||||
self.down1 = ConvBlock(2, base_ch)
|
||||
self.down2 = GinkaEncoder(base_ch, base_ch*2)
|
||||
self.down3 = GinkaEncoder(base_ch*2, base_ch*4)
|
||||
self.down2 = GinkaGCNFusedEncoder(base_ch, base_ch*2, 16, 16)
|
||||
self.down3 = GinkaGCNFusedEncoder(base_ch*2, base_ch*4, 8, 8)
|
||||
self.down4 = GinkaEncoder(base_ch*4, base_ch*8)
|
||||
self.bottleneck = GinkaTransformerEncoder(
|
||||
in_dim=base_ch*8*4*4, hidden_dim=base_ch*8*4*4, out_dim=base_ch*8*4*4,
|
||||
token_size=16, ff_dim=1024, num_layers=4
|
||||
)
|
||||
self.bottleneck = GinkaBottleneck(base_ch*8, 4, 4)
|
||||
|
||||
self.up1 = GinkaDecoder(base_ch*8, base_ch*4)
|
||||
self.up2 = GinkaDecoder(base_ch*4, base_ch*2)
|
||||
self.up3 = GinkaDecoder(base_ch*2, base_ch)
|
||||
self.up1 = GinkaGCNFusedDecoder(base_ch*8, base_ch*4, 8, 8)
|
||||
self.up2 = GinkaGCNFusedDecoder(base_ch*4, base_ch*2, 16, 16)
|
||||
self.up3 = GinkaGCNFusedDecoder(base_ch*2, base_ch, 32, 32)
|
||||
|
||||
self.final = nn.Sequential(
|
||||
nn.Conv2d(base_ch, out_ch, 1),
|
||||
@ -121,9 +231,7 @@ class GinkaUNet(nn.Module):
|
||||
x2 = self.down2(x1) # [B, 128, 16, 16]
|
||||
x3 = self.down3(x2) # [B, 256, 8, 8]
|
||||
x4 = self.down4(x3) # [B, 512, 4, 4]
|
||||
x4 = x4.view(B, 512, 16).permute(0, 2, 1) # [B, 16, 512]
|
||||
x4 = self.bottleneck(x4) # [B, 16, 512]
|
||||
x4 = x4.permute(0, 2, 1).view(B, 512, 4, 4) # [B, 512, 4, 4]
|
||||
x4 = self.bottleneck(x4) # [B, 512, 4, 4]
|
||||
|
||||
# 上采样
|
||||
x = self.up1(x4, x3) # [B, 256, 8, 8]
|
||||
|
||||
@ -11,6 +11,7 @@ from .model.model import GinkaModel
|
||||
from .dataset import GinkaWGANDataset
|
||||
from .model.loss import WGANGinkaLoss
|
||||
from minamo.model.model import MinamoScoreModule
|
||||
from minamo.model.similarity import MinamoSimilarityModel
|
||||
from shared.graph import batch_convert_soft_map_to_graph
|
||||
from shared.image import matrix_to_image_cv
|
||||
from shared.constant import VISION_WEIGHT, TOPO_WEIGHT
|
||||
@ -26,10 +27,12 @@ disable_tqdm = not sys.stdout.isatty()
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(description="training codes")
|
||||
parser.add_argument("--resume", type=bool, default=False)
|
||||
parser.add_argument("--state_ginka", type=str, default="result/ginka.pth")
|
||||
parser.add_argument("--state_minamo", type=str, default="result/minamo.pth")
|
||||
parser.add_argument("--state_ginka", type=str, default="result/wgan/ginka-100.pth")
|
||||
parser.add_argument("--state_minamo", type=str, default="result/wgan/minamo-100.pth")
|
||||
parser.add_argument("--train", type=str, default="ginka-dataset.json")
|
||||
parser.add_argument("--epochs", type=int, default=100)
|
||||
parser.add_argument("--checkpoint", type=int, default=5)
|
||||
parser.add_argument("--load_optim", type=bool, default=True)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
@ -49,14 +52,20 @@ def train():
|
||||
|
||||
ginka = GinkaModel()
|
||||
minamo = MinamoScoreModule()
|
||||
minamo_sim = MinamoSimilarityModel()
|
||||
ginka.to(device)
|
||||
minamo.to(device)
|
||||
minamo_sim.to(device)
|
||||
|
||||
dataset = GinkaWGANDataset(args.train, device)
|
||||
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
|
||||
|
||||
optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9))
|
||||
optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-5, betas=(0.0, 0.9))
|
||||
optimizer_minamo_sim = optim.Adam(minamo_sim.parameters(), lr=1e-4)
|
||||
|
||||
# scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=args.epochs)
|
||||
# scheduler_minamo = optim.lr_scheduler.CosineAnnealingLR(optimizer_minamo, T_max=args.epochs)
|
||||
|
||||
criterion = WGANGinkaLoss()
|
||||
|
||||
@ -67,14 +76,29 @@ def train():
|
||||
tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED)
|
||||
|
||||
if args.resume:
|
||||
data = torch.load(args.state_ginka, map_location=device)
|
||||
ginka.load_state_dict(data["model_state"], strict=False)
|
||||
data = torch.load(args.state_minamo, map_location=device)
|
||||
minamo.load_state_dict(data["model_state"], strict=False)
|
||||
data_ginka = torch.load(args.state_ginka, map_location=device)
|
||||
data_minamo = torch.load(args.state_minamo, map_location=device)
|
||||
|
||||
ginka.load_state_dict(data_ginka["model_state"], strict=False)
|
||||
minamo.load_state_dict(data_minamo["model_state"], strict=False)
|
||||
|
||||
if data_ginka["c_steps"] is not None and data_ginka["g_steps"] is not None:
|
||||
c_steps = data_ginka["c_steps"]
|
||||
g_steps = data_ginka["g_steps"]
|
||||
|
||||
if args.load_optim:
|
||||
if data_ginka["optim_state"] is not None:
|
||||
optimizer_ginka.load_state_dict(data_ginka["optim_state"])
|
||||
if data_minamo["optim_state"] is not None:
|
||||
optimizer_minamo.load_state_dict(data_minamo["optim_state"])
|
||||
if data_minamo["optim_state_sim"] is not None:
|
||||
optimizer_minamo_sim.load_state_dict(data_minamo["optim_state_sim"])
|
||||
|
||||
print("Train from loaded state.")
|
||||
|
||||
for epoch in tqdm(range(args.epochs), desc="GAN Training", disable=disable_tqdm):
|
||||
for epoch in tqdm(range(args.epochs), desc="WGAN Training", disable=disable_tqdm):
|
||||
loss_total_minamo = torch.Tensor([0]).to(device)
|
||||
loss_total_minamo_sim = torch.Tensor([0]).to(device)
|
||||
loss_total_ginka = torch.Tensor([0]).to(device)
|
||||
dis_total = torch.Tensor([0]).to(device)
|
||||
|
||||
@ -83,13 +107,11 @@ def train():
|
||||
real_data = real_data.to(device)
|
||||
real_graph = batch_convert_soft_map_to_graph(real_data)
|
||||
|
||||
optimizer_ginka.zero_grad()
|
||||
|
||||
# ---------- 训练判别器
|
||||
for _ in range(c_steps):
|
||||
# 生成假样本
|
||||
optimizer_minamo.zero_grad()
|
||||
z = torch.randn(batch_size, 1024, device=device)
|
||||
z = torch.rand(batch_size, 1024, device=device)
|
||||
fake_data = ginka(z)
|
||||
fake_data = fake_data.detach()
|
||||
|
||||
@ -97,59 +119,65 @@ def train():
|
||||
# 反向传播
|
||||
dis, loss_d = criterion.discriminator_loss(minamo, real_data, real_graph, fake_data)
|
||||
loss_d.backward()
|
||||
# torch.nn.utils.clip_grad_norm_(minamo.parameters(), max_norm=1.0)
|
||||
# torch.nn.utils.clip_grad_norm_(minamo.parameters(), max_norm=2.0)
|
||||
# total_norm = torch.linalg.vector_norm(torch.stack([torch.linalg.vector_norm(p.grad) for p in minamo.topo_model.parameters()]), 2)
|
||||
# print("Critic 梯度范数:", total_norm.item())
|
||||
# print("Critic 输入范围:", fake_data.min().item(), fake_data.max().item(), real_data.min().item(), real_data.max().item())
|
||||
# print("Critic 输出范围:", d_real.min().item(), d_real.max().item())
|
||||
optimizer_minamo.step()
|
||||
|
||||
loss_total_minamo += loss_d
|
||||
dis_total += dis
|
||||
loss_total_minamo += loss_d.detach()
|
||||
dis_total += dis.detach()
|
||||
|
||||
# ---------- 训练生成器
|
||||
|
||||
for _ in range(g_steps):
|
||||
optimizer_ginka.zero_grad()
|
||||
# optimizer_minamo_sim.zero_grad()
|
||||
|
||||
z1 = torch.randn(batch_size, 1024, device=device)
|
||||
z2 = torch.randn(batch_size, 1024, device=device)
|
||||
fake_softmax1, fake_softmax2 = ginka(z1), ginka(z2)
|
||||
|
||||
loss_g = criterion.generator_loss(minamo, fake_softmax1, fake_softmax2)
|
||||
# 先训练辅助判别器
|
||||
# loss_c_assist = criterion.discriminator_loss_assist2(minamo_sim, real_data, fake_softmax1, fake_softmax2)
|
||||
# loss_c_assist.backward(retain_graph=True)
|
||||
# optimizer_minamo_sim.step()
|
||||
|
||||
loss_g = criterion.generator_loss(minamo, minamo_sim, fake_softmax1, fake_softmax2)
|
||||
loss_g.backward()
|
||||
optimizer_ginka.step()
|
||||
|
||||
loss_total_ginka += loss_g
|
||||
# loss_total_minamo_sim += loss_c_assist.detach()
|
||||
# tqdm.write(f"{dis.item():.12f}, {loss_d.item():.12f}, {loss_g.item():.12f}")
|
||||
|
||||
avg_loss_ginka = loss_total_ginka.item() / len(dataloader) / g_steps
|
||||
avg_loss_minamo = loss_total_minamo.item() / len(dataloader) / c_steps
|
||||
avg_loss_minamo_sim = loss_total_minamo_sim.item() / len(dataloader) / g_steps
|
||||
avg_dis = dis_total.item() / len(dataloader) / c_steps
|
||||
tqdm.write(
|
||||
f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | Wasserstein Loss: {avg_dis:.8f} | Loss Ginka: {avg_loss_ginka:.8f} | Loss Minamo: {avg_loss_minamo:.8f}"
|
||||
f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +\
|
||||
f"Epoch: {epoch + 1} | W Loss: {avg_dis:.8f} | " +\
|
||||
f"G Loss: {avg_loss_ginka:.8f} | D Loss: {avg_loss_minamo:.8f} | " +\
|
||||
f"lr G: {(optimizer_ginka.param_groups[0]['lr']):.8f}"
|
||||
)
|
||||
|
||||
# scheduler_ginka.step()
|
||||
# scheduler_minamo.step()
|
||||
|
||||
if avg_dis < 0:
|
||||
g_steps = max(int(-avg_dis * 5), 1)
|
||||
else:
|
||||
g_steps = 1
|
||||
|
||||
# if avg_dis > 0:
|
||||
# c_steps = min(max(int(avg_dis * 5), 1), 5)
|
||||
# else:
|
||||
# c_steps = 1
|
||||
|
||||
# if avg_loss_minamo > 0:
|
||||
# c_steps += min(max(int(avg_loss_minamo * 3), 1), 5)
|
||||
# else:
|
||||
# c_steps += 0
|
||||
|
||||
# if avg_dis > 3:
|
||||
# c_steps = 3
|
||||
# else:
|
||||
# c_steps = 1
|
||||
if avg_loss_ginka > 0 or avg_loss_minamo > 0:
|
||||
c_steps = min(5 + (avg_loss_ginka + avg_loss_minamo) * 5, 15)
|
||||
else:
|
||||
c_steps = 5
|
||||
|
||||
# 每若干轮输出一次图片,并保存检查点
|
||||
if (epoch + 1) % 5 == 0:
|
||||
if (epoch + 1) % args.checkpoint == 0:
|
||||
# 输出 20 张图片,每批次 4 张,一共五批
|
||||
idx = 0
|
||||
with torch.no_grad():
|
||||
@ -165,18 +193,25 @@ def train():
|
||||
|
||||
# 保存检查点
|
||||
torch.save({
|
||||
"model_state": ginka.state_dict()
|
||||
"model_state": ginka.state_dict(),
|
||||
"optim_state": optimizer_ginka.state_dict(),
|
||||
"c_steps": c_steps,
|
||||
"g_steps": g_steps
|
||||
}, f"result/wgan/ginka-{epoch + 1}.pth")
|
||||
torch.save({
|
||||
"model_state": minamo.state_dict()
|
||||
"model_state": minamo.state_dict(),
|
||||
"model_state_sim": minamo_sim.state_dict(),
|
||||
"optim_state": optimizer_minamo.state_dict(),
|
||||
"optim_state_sim": optimizer_minamo_sim.state_dict()
|
||||
}, f"result/wgan/minamo-{epoch + 1}.pth")
|
||||
|
||||
print("Train ended.")
|
||||
torch.save({
|
||||
"model_state": ginka.state_dict()
|
||||
"model_state": ginka.state_dict(),
|
||||
}, f"result/ginka.pth")
|
||||
torch.save({
|
||||
"model_state": minamo.state_dict()
|
||||
"model_state": minamo.state_dict(),
|
||||
"model_state_sim": minamo_sim.state_dict(),
|
||||
}, f"result/minamo.pth")
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -20,40 +20,6 @@ class MinamoModel(nn.Module):
|
||||
|
||||
return vision_feat, topo_feat
|
||||
|
||||
class MinamoVisionScore(nn.Module):
|
||||
def __init__(self, tile_types=32):
|
||||
super().__init__()
|
||||
# 视觉相似度部分
|
||||
self.vision_model = MinamoVisionModel(tile_types)
|
||||
# 输出层
|
||||
self.vision_fc = nn.Sequential(
|
||||
nn.Linear(512, 2048),
|
||||
nn.LeakyReLU(0.2),
|
||||
nn.Linear(2048, 1)
|
||||
)
|
||||
|
||||
def forward(self, map):
|
||||
vision_feat = self.vision_model(map)
|
||||
vis_score = self.vision_fc(vision_feat)
|
||||
return vis_score
|
||||
|
||||
class MinamoTopoScore(nn.Module):
|
||||
def __init__(self, tile_types=32):
|
||||
super().__init__()
|
||||
# 拓扑相似度部分
|
||||
self.topo_model = MinamoTopoModel(tile_types)
|
||||
# 输出层
|
||||
self.topo_fc = nn.Sequential(
|
||||
nn.Linear(512, 2048),
|
||||
nn.LeakyReLU(0.2),
|
||||
nn.Linear(2048, 1)
|
||||
)
|
||||
|
||||
def forward(self, graph):
|
||||
topo_feat = self.topo_model(graph)
|
||||
topo_score = self.topo_fc(topo_feat)
|
||||
return topo_score
|
||||
|
||||
class MinamoScoreModule(nn.Module):
|
||||
def __init__(self, tile_types=32):
|
||||
super().__init__()
|
||||
|
||||
83
minamo/model/similarity.py
Normal file
83
minamo/model/similarity.py
Normal file
@ -0,0 +1,83 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch_geometric.nn import GCNConv, global_mean_pool
|
||||
from torch_geometric.data import Data
|
||||
|
||||
class MinamoSimilarityVision(nn.Module):
|
||||
def __init__(self, in_ch, out_ch):
|
||||
super().__init__()
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(in_ch, in_ch * 2, 3, padding=1),
|
||||
nn.InstanceNorm2d(in_ch * 2),
|
||||
nn.ReLU(),
|
||||
|
||||
nn.Conv2d(in_ch * 2, in_ch * 4, 3, padding=1),
|
||||
nn.InstanceNorm2d(in_ch * 4),
|
||||
nn.ReLU(),
|
||||
|
||||
nn.Conv2d(in_ch * 4, in_ch * 8, 3),
|
||||
nn.InstanceNorm2d(in_ch * 8),
|
||||
nn.ReLU(),
|
||||
|
||||
nn.AdaptiveAvgPool2d(1)
|
||||
)
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(in_ch * 8, out_ch),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
class MinamoSimilarityTopo(nn.Module):
|
||||
def __init__(self, in_ch, hidden_dim, out_ch):
|
||||
super().__init__()
|
||||
self.input_fc = nn.Sequential(
|
||||
nn.Linear(in_ch, hidden_dim),
|
||||
nn.LayerNorm(hidden_dim),
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
self.conv1 = GCNConv(hidden_dim, hidden_dim*2)
|
||||
self.conv2 = GCNConv(hidden_dim*2, hidden_dim*4)
|
||||
self.conv3 = GCNConv(hidden_dim*4, hidden_dim*8)
|
||||
|
||||
self.norm1 = nn.LayerNorm(hidden_dim*2)
|
||||
self.norm2 = nn.LayerNorm(hidden_dim*4)
|
||||
self.norm3 = nn.LayerNorm(hidden_dim*8)
|
||||
|
||||
self.output_fc = nn.Sequential(
|
||||
nn.Linear(hidden_dim*8, out_ch)
|
||||
)
|
||||
|
||||
def forward(self, graph: Data):
|
||||
x = self.input_fc(graph.x)
|
||||
|
||||
x = self.conv1(x, graph.edge_index)
|
||||
x = F.relu(self.norm1(x))
|
||||
|
||||
x = self.conv2(x, graph.edge_index)
|
||||
x = F.relu(self.norm2(x))
|
||||
|
||||
x = self.conv3(x, graph.edge_index)
|
||||
x = F.relu(self.norm3(x))
|
||||
|
||||
x = global_mean_pool(x, graph.batch)
|
||||
x = self.output_fc(x)
|
||||
|
||||
return x
|
||||
|
||||
class MinamoSimilarityModel(nn.Module):
|
||||
def __init__(self, tile_type=32):
|
||||
super().__init__()
|
||||
self.vision = MinamoSimilarityVision(tile_type, 512)
|
||||
self.topo = MinamoSimilarityTopo(tile_type, 64, 512)
|
||||
|
||||
def forward(self, x, graph):
|
||||
vis_feat = self.vision(x)
|
||||
topo_feat = self.topo(graph)
|
||||
return vis_feat, topo_feat
|
||||
|
||||
@ -7,16 +7,16 @@ class MinamoVisionModel(nn.Module):
|
||||
def __init__(self, in_ch=32, out_dim=512):
|
||||
super().__init__()
|
||||
self.conv = nn.Sequential(
|
||||
spectral_norm(nn.Conv2d(in_ch, in_ch*2, 3, stride=2)), # 6*6
|
||||
spectral_norm(nn.Conv2d(in_ch, in_ch*2, 3)), # 11*11
|
||||
nn.LeakyReLU(0.2),
|
||||
|
||||
spectral_norm(nn.Conv2d(in_ch*2, in_ch*4, 3)), #4*4
|
||||
spectral_norm(nn.Conv2d(in_ch*2, in_ch*4, 3)), #9*9
|
||||
nn.LeakyReLU(0.2),
|
||||
|
||||
spectral_norm(nn.Conv2d(in_ch*4, in_ch*8, 3)), # 2*2
|
||||
spectral_norm(nn.Conv2d(in_ch*4, in_ch*8, 3)), # 7*7
|
||||
nn.LeakyReLU(0.2),
|
||||
|
||||
nn.Flatten()
|
||||
nn.AdaptiveAvgPool2d(2)
|
||||
)
|
||||
self.fc = nn.Sequential(
|
||||
spectral_norm(nn.Linear(in_ch*8*2*2, out_dim)),
|
||||
@ -25,5 +25,6 @@ class MinamoVisionModel(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
@ -9,7 +9,7 @@ class ChannelAttention(nn.Module):
|
||||
self.channel_att = nn.Sequential(
|
||||
nn.AdaptiveAvgPool2d(1),
|
||||
nn.Conv2d(channels, channels//reduction, 1),
|
||||
nn.GELU(),
|
||||
nn.ELU(),
|
||||
nn.Conv2d(channels//reduction, channels, 1),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
|
||||
@ -2,5 +2,5 @@ VIS_DIM = 512
|
||||
TOPO_DIM = 512
|
||||
FEAT_DIM = 1024
|
||||
|
||||
VISION_WEIGHT = 0.3
|
||||
TOPO_WEIGHT = 0.7
|
||||
VISION_WEIGHT = 0.2
|
||||
TOPO_WEIGHT = 0.8
|
||||
|
||||
273
shared/similarity/topo.py
Normal file
273
shared/similarity/topo.py
Normal file
@ -0,0 +1,273 @@
|
||||
# Converted Python version of the JS code
|
||||
import math
|
||||
from typing import Dict, Set, List, Tuple, Union
|
||||
from collections import deque, defaultdict
|
||||
|
||||
# 拓扑相似度,由 ChatGPT-4o 从 ts 转译而来
|
||||
|
||||
class ResourceArea:
|
||||
def __init__(self):
|
||||
self.type = 'resource'
|
||||
self.resources: Dict[int, int] = {}
|
||||
self.members: Set[int] = set()
|
||||
self.neighbor: Set[int] = set()
|
||||
|
||||
class BranchNode:
|
||||
def __init__(self, tile: int):
|
||||
self.type = 'branch'
|
||||
self.tile = tile
|
||||
self.neighbor: Set[int] = set()
|
||||
|
||||
class ResourceNode:
|
||||
def __init__(self, resource_type: int, area: ResourceArea):
|
||||
self.type = 'resource'
|
||||
self.resourceType = resource_type
|
||||
self.neighbor = area.neighbor
|
||||
self.resourceArea = area
|
||||
|
||||
GinkaNode = Union[BranchNode, ResourceNode]
|
||||
|
||||
class GinkaGraph:
|
||||
def __init__(self):
|
||||
self.graph: Dict[int, GinkaNode] = {}
|
||||
self.resourceMap: Dict[int, int] = {}
|
||||
self.areaMap: List[ResourceArea] = []
|
||||
self.visitedEntrance: Set[int] = set()
|
||||
self.visited: Set[int] = set()
|
||||
|
||||
class GinkaTopologicalGraphs:
|
||||
def __init__(self):
|
||||
self.graphs: List[GinkaGraph] = []
|
||||
self.entranceMap: Dict[int, GinkaGraph] = {}
|
||||
self.unreachable: Set[int] = set()
|
||||
|
||||
TILE_TYPE = set(range(13))
|
||||
BRANCH_TYPE = {6, 7, 8, 9}
|
||||
ENTRANCE_TYPE = {10, 11}
|
||||
RESOURCE_TYPE = {0, 2, 3, 4, 5, 10, 11, 12}
|
||||
|
||||
directions: List[Tuple[int, int]] = [
|
||||
(-1, 0), (1, 0), (0, -1), (0, 1)
|
||||
]
|
||||
|
||||
def find_resource_nodes(map_: List[List[int]]):
|
||||
width, height = len(map_[0]), len(map_)
|
||||
visited = set()
|
||||
areas = []
|
||||
resource_map = {}
|
||||
|
||||
for ny in range(height):
|
||||
for nx in range(width):
|
||||
tile = map_[ny][nx]
|
||||
index = ny * width + nx
|
||||
if index in visited or tile not in RESOURCE_TYPE:
|
||||
continue
|
||||
queue = deque([(nx, ny)])
|
||||
area = ResourceArea()
|
||||
area.resources[tile] = 1
|
||||
area.members.add(index)
|
||||
while queue:
|
||||
cx, cy = queue.popleft()
|
||||
cindex = cy * width + cx
|
||||
if cindex in visited:
|
||||
continue
|
||||
ctile = map_[cy][cx]
|
||||
if ctile not in RESOURCE_TYPE:
|
||||
continue
|
||||
visited.add(cindex)
|
||||
area.resources[ctile] = area.resources.get(ctile, 0) + 1
|
||||
area.members.add(cindex)
|
||||
resource_map[cindex] = len(areas)
|
||||
for dx, dy in directions:
|
||||
px, py = cx + dx, cy + dy
|
||||
if 0 <= px < width and 0 <= py < height:
|
||||
queue.append((px, py))
|
||||
areas.append(area)
|
||||
return areas, resource_map
|
||||
|
||||
def build_graph_from_entrance(map_: List[List[int]], entrance: int, resource_map: Dict[int, int], area_map: List[ResourceArea]) -> GinkaGraph:
|
||||
width, height = len(map_[0]), len(map_)
|
||||
graph = GinkaGraph()
|
||||
graph.resourceMap = resource_map
|
||||
graph.areaMap = area_map
|
||||
|
||||
visited = graph.visited
|
||||
visited_entrance = graph.visitedEntrance
|
||||
visited_entrance.add(entrance)
|
||||
|
||||
branch_nodes = set()
|
||||
queue = deque([(entrance % width, entrance // width)])
|
||||
|
||||
while queue:
|
||||
x, y = queue.popleft()
|
||||
index = y * width + x
|
||||
if index in visited:
|
||||
continue
|
||||
tile = map_[y][x]
|
||||
if tile in ENTRANCE_TYPE:
|
||||
visited_entrance.add(index)
|
||||
if tile in BRANCH_TYPE:
|
||||
branch_nodes.add(index)
|
||||
visited.add(index)
|
||||
for dx, dy in directions:
|
||||
px, py = x + dx, y + dy
|
||||
if 0 <= px < width and 0 <= py < height and map_[py][px] != 1:
|
||||
queue.append((px, py))
|
||||
|
||||
for v in branch_nodes:
|
||||
x, y = v % width, v // width
|
||||
if v not in graph.graph:
|
||||
graph.graph[v] = BranchNode(map_[y][x])
|
||||
node = graph.graph[v]
|
||||
for dx, dy in directions:
|
||||
px, py = x + dx, y + dy
|
||||
if 0 <= px < width and 0 <= py < height:
|
||||
index = py * width + px
|
||||
if index in branch_nodes:
|
||||
node.neighbor.add(index)
|
||||
elif index in resource_map:
|
||||
area = area_map[resource_map[index]]
|
||||
area.neighbor.add(v)
|
||||
for m in area.members:
|
||||
node.neighbor.add(m)
|
||||
|
||||
for area in area_map:
|
||||
for index in area.members:
|
||||
x, y = index % width, index // width
|
||||
tile = map_[y][x]
|
||||
if tile == 0:
|
||||
continue
|
||||
node = ResourceNode(tile, area)
|
||||
graph.graph[index] = node
|
||||
|
||||
return graph
|
||||
|
||||
def build_topological_graph(map_: List[List[int]]) -> GinkaTopologicalGraphs:
    """Partition a map into connected topological graphs, one per entrance group.

    Every entrance seeds a flood fill; entrances reached by an earlier fill
    share that fill's graph instead of starting a new one. Non-wall tiles
    never reached from any entrance are recorded as unreachable.

    Args:
        map_: 2D tile grid; tile value 1 is a wall.

    Returns:
        A GinkaTopologicalGraphs holding the per-component graphs, an
        entrance -> graph lookup, and the set of unreachable tile indices.
    """
    width, height = len(map_[0]), len(map_)
    # Flat index of every tile whose type marks an entrance.
    # (The original code had a dead `entrances = set()` assignment here,
    # immediately overwritten — removed.)
    entrances = {y * width + x for y in range(height) for x in range(width) if map_[y][x] in ENTRANCE_TYPE}
    area_map, resource_map = find_resource_nodes(map_)

    top_graph = GinkaTopologicalGraphs()
    used_entrance = set()
    total_visited = set()

    for entrance in entrances:
        # Skip entrances already absorbed by a previous flood fill.
        if entrance in used_entrance:
            continue
        graph = build_graph_from_entrance(map_, entrance, resource_map, area_map)
        top_graph.graphs.append(graph)
        for ent in graph.visitedEntrance:
            used_entrance.add(ent)
            top_graph.entranceMap[ent] = graph
        total_visited.update(graph.visited)

    # Any walkable tile no fill ever touched is unreachable from all entrances.
    for y in range(height):
        for x in range(width):
            index = y * width + x
            if index not in total_visited and map_[y][x] != 1:
                top_graph.unreachable.add(index)

    return top_graph
class WLNode:
    """A node wrapper used by the Weisfeiler-Lehman relabeling procedure."""

    def __init__(self, pos: int, label: str):
        # Where the node came from and the label it started with.
        self.originalPos = pos
        self.originalLabel = label
        # Working label, rewritten on every WL iteration.
        self.currentLabel = label
        # Adjacent WL nodes; wired up after all nodes are created.
        self.neighbors: List['WLNode'] = []
def encode_node_labels(graph: GinkaGraph) -> List[WLNode]:
    """Wrap every graph node in a WLNode carrying its initial WL label.

    Branch nodes are labeled ``B:<tile>``, resource nodes ``R:<resourceType>``.
    Adjacency is copied, keeping only neighbors that exist in the graph.
    """
    # First pass: one WLNode per graph node, keyed by flat position.
    by_pos: Dict[int, WLNode] = {}
    for pos, g_node in graph.graph.items():
        if g_node.type == 'branch':
            initial = f"B:{g_node.tile}"
        else:
            initial = f"R:{g_node.resourceType}"
        by_pos[pos] = WLNode(pos, initial)

    wl_nodes = list(by_pos.values())

    # Second pass: wire adjacency, dropping neighbors without a WLNode.
    for wl_node in wl_nodes:
        for adj in graph.graph[wl_node.originalPos].neighbor:
            if adj in by_pos:
                wl_node.neighbors.append(by_pos[adj])

    return wl_nodes
def weisfeiler_lehman_iteration(nodes: List[WLNode], iterations: int, decay: float = 0.6) -> Dict[str, float]:
    """Run WL relabeling rounds and return a decayed label-count histogram.

    Each round rewrites every node's label as itself joined with its sorted
    neighbor labels (truncated to 8192 chars). Earlier rounds contribute with
    weight 1.0, each later round scaled by ``decay``; the untouched original
    labels are added last with the smallest weight.
    """
    label_history = []
    for _ in range(iterations):
        # Compute all composites first so each round sees a consistent snapshot.
        round_labels = [
            f"{node.currentLabel}|{','.join(sorted(n.currentLabel for n in node.neighbors))}"[:8192]
            for node in nodes
        ]
        for node, label in zip(nodes, round_labels):
            node.currentLabel = label
        label_history.append(list(round_labels))

    counts: Dict[str, float] = defaultdict(float)
    weight = 1.0
    for round_labels in label_history:
        for label in round_labels:
            counts[label] += weight
        weight *= decay
    # Original labels enter with the final (smallest) weight.
    for node in nodes:
        counts[node.originalLabel] += weight
    return dict(counts)
def vectorize_features(features: Dict[str, float], vocab: List[str]) -> List[float]:
    """Project a label-count histogram onto a fixed vocabulary vector.

    Args:
        features: Label -> accumulated count.
        vocab: Ordered vocabulary; each label maps to one vector slot.

    Returns:
        A list aligned with ``vocab``; labels absent from it are ignored.
    """
    # Build the label -> slot lookup once. The original called
    # `vocab.index(label)` (after an `in` scan) per feature, which is
    # O(len(features) * len(vocab)); this is O(len(features) + len(vocab)).
    # setdefault preserves list.index's first-occurrence semantics should
    # vocab ever contain duplicates.
    slot = {}
    for i, label in enumerate(vocab):
        slot.setdefault(label, i)

    vec = [0.0] * len(vocab)
    for label, count in features.items():
        idx = slot.get(label)
        if idx is not None:
            vec[idx] += count
    return vec
def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Cosine of the angle between two vectors; 0.0 when either has zero norm."""
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    # A zero vector has no direction — define its similarity as 0.
    if norm_a == 0 or norm_b == 0:
        return 0.0
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (norm_a * norm_b)
def wl_kernel(graph_a: GinkaGraph, graph_b: GinkaGraph, iterations: int = 3) -> float:
    """Weisfeiler-Lehman kernel similarity between two graphs, in [0, 1]."""
    features_a = weisfeiler_lehman_iteration(encode_node_labels(graph_a), iterations)
    features_b = weisfeiler_lehman_iteration(encode_node_labels(graph_b), iterations)
    # Shared vocabulary spanning both feature histograms.
    vocab = list(set(features_a.keys()) | set(features_b.keys()))
    return cosine_similarity(
        vectorize_features(features_a, vocab),
        vectorize_features(features_b, vocab),
    )
def overall_similarity(a: GinkaTopologicalGraphs, b: GinkaTopologicalGraphs) -> float:
    """Greedy best-match similarity between two sets of topological graphs.

    Each graph of ``a`` is matched to its most similar still-unmatched graph
    of ``b`` (WL kernel); the mean best similarity is square-rooted and
    reduced by the difference in unreachable-tile counts.

    Returns:
        A score in [0, 1]; 0.0 when ``a`` has no graphs.
    """
    graphs_a = a.graphs
    graphs_b = b.graphs
    # Early out: nothing to compare (also avoids division by zero below).
    if not graphs_a:
        return 0.0

    total_similarity = 0.0
    compared_graphs: Set[GinkaGraph] = set()

    for ga in graphs_a:
        max_similarity = 0.0
        max_graph = None
        for gb in graphs_b:
            # One-to-one matching: a graph of b is consumed once.
            if gb in compared_graphs:
                continue
            min_nodes = min(len(ga.graph), len(gb.graph))
            # Bug fix: math.log(0) raises ValueError when either graph is
            # empty — clamp to at least one WL iteration.
            iterations = max(1, math.ceil(math.log(min_nodes))) if min_nodes > 0 else 1
            similarity = wl_kernel(ga, gb, iterations)
            if similarity > max_similarity and not math.isnan(similarity):
                max_similarity = similarity
                max_graph = gb
                # Perfect match — no better candidate can exist.
                if similarity == 1:
                    break
        total_similarity += max_similarity
        if max_graph:
            compared_graphs.add(max_graph)

    # Penalize differing amounts of unreachable space.
    reduction = 1 / (1 + abs(len(a.unreachable) - len(b.unreachable)))
    return math.sqrt(total_similarity / len(graphs_a)) * reduction
75
shared/similarity/vision.py
Normal file
75
shared/similarity/vision.py
Normal file
@ -0,0 +1,75 @@
|
||||
from typing import List, Dict
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
# Visual similarity metric, transpiled from the TypeScript version by ChatGPT-4o
|
||||
|
||||
class VisualSimilarityConfig:
    """Tunable knobs for the visual similarity metric."""

    def __init__(self):
        # Per-tile-type importance; call sites fall back to 0.5 for
        # tile types missing from this table.
        self.type_weights: Dict[int, float] = {
            0: 0.2, 1: 0.3, 2: 0.6, 3: 0.7, 4: 0.7, 5: 0.5,
            6: 0.4, 7: 0.5, 8: 0.6, 9: 0.6, 10: 0.4, 11: 0.4, 12: 0.7
        }
        # Weight tiles near the map center more heavily (Gaussian focus).
        self.enable_visual_focus: bool = True
        # Scale per-tile scores by local tile density.
        self.enable_density_awareness: bool = True
def generate_focus_weights(rows: int, cols: int, sigma: float = 0.3, boost: float = 0.6) -> List[List[float]]:
    """Build a Gaussian center-focus weight matrix.

    Each cell gets ``1.0 + boost * exp(-d^2 / (2 * sigma^2))`` where ``d`` is
    the cell's offset from the grid center, normalized by grid size so the
    falloff is aspect-ratio independent.

    Args:
        rows: Number of grid rows.
        cols: Number of grid columns.
        sigma: Gaussian falloff width (default 0.3 — original hard-coded value).
        boost: Maximum extra weight at the center (default 0.6 — original value).

    Returns:
        A rows x cols matrix of weights in (1.0, 1.0 + boost].
    """
    center_x = cols / 2
    center_y = rows / 2
    # Loop-invariant denominator, hoisted out of the double loop.
    denom = 2 * sigma ** 2
    weights = []
    for i in range(rows):
        row_weights = []
        for j in range(cols):
            dx = (j - center_x) / cols
            dy = (i - center_y) / rows
            # Squared distance used directly — the original took a sqrt and
            # then re-squared it inside exp().
            distance_sq = dx ** 2 + dy ** 2
            gaussian = math.exp(-distance_sq / denom)
            row_weights.append(1.0 + boost * gaussian)
        weights.append(row_weights)
    return weights
def calculate_density_impact(map1: List[List[int]], map2: List[List[int]], type_weights: Dict[int, float]) -> List[List[float]]:
    """Per-cell density multiplier from a 3x3 window over both maps.

    For each cell, sums the averaged type weights of both maps over the
    in-bounds 3x3 neighborhood (unknown types weigh 0.5) and maps the sum to
    a multiplier ``1.0 + 0.4 * (density / 9)``.
    """
    rows, cols = len(map1), len(map1[0])
    window_size = 3
    half_window = window_size // 2
    offsets = range(-half_window, half_window + 1)

    density_map = []
    for i in range(rows):
        row = []
        for j in range(cols):
            density = 0
            for di in offsets:
                ni = i + di
                if not (0 <= ni < rows):
                    continue
                for dj in offsets:
                    nj = j + dj
                    if 0 <= nj < cols:
                        w1 = type_weights.get(map1[ni][nj], 0.5)
                        w2 = type_weights.get(map2[ni][nj], 0.5)
                        density += (w1 + w2) / 2
            row.append(1.0 + 0.4 * (density / (window_size ** 2)))
        density_map.append(row)
    return density_map
def calculate_visual_similarity(map1: List[List[int]], map2: List[List[int]], config: VisualSimilarityConfig = None) -> float:
    """Weighted per-tile agreement between two maps, in [0, 1].

    Each cell scores 1 when the tile types match, scaled by the larger of
    the two types' weights and by optional center-focus and density factors;
    the total is normalized by the maximum attainable score.

    Args:
        map1: First tile grid.
        map2: Second tile grid; must have identical dimensions.
        config: Optional tuning; a default VisualSimilarityConfig is used
            when None.

    Returns:
        Similarity in [0, 1]; 0.0 for empty maps or mismatched dimensions.
    """
    if config is None:
        config = VisualSimilarityConfig()

    # Bug fix: two empty maps passed the equality check and then crashed on
    # map1[0]. Empty input is simply "nothing alike".
    if not map1 or not map2:
        return 0.0
    if len(map1) != len(map2) or len(map1[0]) != len(map2[0]):
        return 0.0

    rows, cols = len(map1), len(map1[0])
    total_score = 0.0
    max_possible_score = 0.0

    # Neutral (all-1.0) matrices when a feature is disabled.
    focus_weights = generate_focus_weights(rows, cols) if config.enable_visual_focus else [[1.0 for _ in range(cols)] for _ in range(rows)]
    density_map = calculate_density_impact(map1, map2, config.type_weights) if config.enable_density_awareness else [[1.0 for _ in range(cols)] for _ in range(rows)]

    for i in range(rows):
        for j in range(cols):
            type1 = map1[i][j]
            type2 = map2[i][j]
            # The more "important" of the two tile types dominates the cell.
            base_weight = max(config.type_weights.get(type1, 0.5), config.type_weights.get(type2, 0.5))
            spatial_weight = focus_weights[i][j] * density_map[i][j]
            type_score = 1.0 if type1 == type2 else 0.0

            total_score += type_score * base_weight * spatial_weight
            max_possible_score += base_weight * spatial_weight

    return total_score / max_possible_score if max_possible_score > 0 else 0.0
Loading…
Reference in New Issue
Block a user