From 29cfb4d029cfbada999e82acc3bb750bbbcf7045 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Sun, 6 Apr 2025 18:44:18 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E6=94=B9=E4=B8=BA=20Wasserstein=20?= =?UTF-8?q?GAN?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data/src/gan.ts | 24 ++--- data/src/process/minamo.ts | 75 ++++++++------- data/src/topology/similarity.ts | 3 +- ginka/dataset.py | 20 ++++ ginka/model/loss.py | 123 +++++++++++++++++++++++- ginka/model/model.py | 17 ++-- ginka/model/unet.py | 136 +++++++++++++++----------- ginka/train_gan.py | 122 ++++++++++++++++++------ ginka/train_wgan.py | 164 ++++++++++++++++++++++++++++++++ minamo/model/model.py | 61 ++++++++++++ minamo/model/topo.py | 46 ++++----- minamo/model/vision.py | 26 +++-- shared/constant.py | 6 ++ shared/graph.py | 3 + 14 files changed, 643 insertions(+), 183 deletions(-) create mode 100644 ginka/train_wgan.py create mode 100644 shared/constant.py diff --git a/data/src/gan.ts b/data/src/gan.ts index 4edd797..24ec44e 100644 --- a/data/src/gan.ts +++ b/data/src/gan.ts @@ -4,7 +4,7 @@ import { MinamoTrainData } from './types'; import { generateTrainData } from './process/minamo'; const SOCKET_FILE = '../tmp/ginka_uds'; -const [refer] = process.argv.slice(2); +const [refer, replayPath = '../datasets/replay.bin'] = process.argv.slice(2); let id = 0; @@ -43,7 +43,7 @@ function generateGANData( const size2: [number, number] = [map[0].length, map.length]; if (size1[0] !== size2[0] || size1[1] !== size2[1]) return []; - return generateTrainData(v, id2, floor.map, map, size1); + return generateTrainData(v, id2, floor.map, map, size1, false, false, false); }); return data.flat(); } @@ -104,7 +104,7 @@ class DataReceiver { console.log(`UDS IPC connected successfully.`); }); - client.on('data', buffer => { + client.on('data', async buffer => { const data = DataReceiver.check(buffer); if (!data) return; @@ -112,20 +112,19 @@ class DataReceiver { const simData = map.map(v => generateGANData(keys, referTower, v)); const rc = 0; const compareData = simData.flat(); - const reviewData: MinamoTrainData[] = []; // 数据通讯 node 输出协议,单位字节: - // 2 - Tensor count; 2 - Review count. Review is right behind train data; + // 2 - Tensor count; 2 - Replay count. Replay is right behind train data; // 1*tc - Compare count for every map tensor delivered. // 2*4*(N+rc) - Vision similarity and topo similarity, like vis, topo, vis, topo; - // N*1*H*W - Compare map for every map tensor. rc*2*H*W - Review map tensor. + // N*1*H*W - Compare map for every map tensor. rc*2*H*W - Replay map tensor. const toSend = Buffer.alloc( 2 + // Tensor count - 2 + // Review count + 2 + // Replay count 1 * count + // Compare count 2 * 4 * (compareData.length + rc) + // Similarity data compareData.length * 1 * h * w + // Compare map - rc * 2 * h * w, // Review map + rc * 2 * h * w, // Replay map 0 ); console.log( @@ -141,7 +140,7 @@ class DataReceiver { let offset = 0; toSend.writeInt16BE(count); // Tensor count - toSend.writeInt16BE(0, 2); // Review count + toSend.writeInt16BE(0, 2); // Replay count offset += 2 + 2; // Compare count toSend.set( @@ -164,13 +163,6 @@ class DataReceiver { offset // Set from Compare map ); offset += compareData.length * 1 * h * w; - if (reviewData.length > 0) { - // Review map - toSend.set( - new Uint8Array(reviewData.map(v => [v.map1, v.map2]).flat(4)), - offset // Set from last chunk - ); - } client.write(toSend); }); diff --git a/data/src/process/minamo.ts b/data/src/process/minamo.ts index 40471b4..085b009 100644 --- a/data/src/process/minamo.ts +++ b/data/src/process/minamo.ts @@ -192,7 +192,10 @@ export function generateTrainData( id2: string, map1: number[][], map2: number[][], - size: [number, number] + size: [number, number], + hasSelf: boolean = true, + hasTransform: boolean = true, + hasSimilar: boolean = true ) { const topoSimilarity = compareMap(id1, id2, map1, map2); const visionSimilarity = calculateVisualSimilarity(map1, map2); @@ -205,39 +208,47 @@ export function generateTrainData( }; const data: MinamoTrainData[] = []; data.push(train); - // 自身与自身对比的训练集,保证模型对相同地图输出 1 - const self1 = `${id1}:${id1}`; - const self2 = `${id2}:${id2}`; - const selfTrain = chooseFrom([self1, self2], Math.floor(Math.random() * 1)); - if (selfTrain.includes(self1)) { - const selfTrain1: MinamoTrainData = { - map1: map1, - map2: map1, - topoSimilarity: 1, - visionSimilarity: 1, - size: size - }; - data.push(selfTrain1); + if (hasSelf) { + // 自身与自身对比的训练集,保证模型对相同地图输出 1 + const self1 = `${id1}:${id1}`; + const self2 = `${id2}:${id2}`; + const selfTrain = chooseFrom([self1, self2], Math.floor(Math.random() * 1)); + if (selfTrain.includes(self1)) { + const selfTrain1: MinamoTrainData = { + map1: map1, + map2: map1, + topoSimilarity: 1, + visionSimilarity: 1, + size: size + }; + data.push(selfTrain1); + } + if (selfTrain.includes(self2)) { + const selfTrain2: MinamoTrainData = { + map1: map2, + map2: map2, + topoSimilarity: 1, + visionSimilarity: 1, + size: size + }; + data.push(selfTrain2); + } } - if (selfTrain.includes(self2)) { - const selfTrain2: MinamoTrainData = { - map1: map2, - map2: map2, - topoSimilarity: 1, - visionSimilarity: 1, - size: size - }; - data.push(selfTrain2); + if (hasTransform) { + const transform = generateTransformData( + id1, + id2, + map1, + map2, + topoSimilarity + ); + data.push(...transform.map(v => v[1])) } - const transform = generateTransformData( - id1, - id2, - map1, - map2, - topoSimilarity - ); - const similar = generateSimilarData(id1, map1); - return [...data, ...transform.map(v => v[1]), ...similar.map(v => v[1])]; + if (hasSimilar) { + const similar = generateSimilarData(id1, map1); + data.push(...similar.map(v => v[1])) + } + return data; } export function generatePair( diff --git a/data/src/topology/similarity.ts b/data/src/topology/similarity.ts index abb345b..9275d3c 100644 --- a/data/src/topology/similarity.ts +++ b/data/src/topology/similarity.ts @@ -87,8 +87,7 @@ function weisfeilerLehmanIteration( }); weight *= decay; }); - // 把每个节点的原始标签也加上,权重使用最远权重再衰减1次,可以认为是资源重复率 - weight *= decay; + // 把每个节点的原始标签也加上,权重使用最远权重,可以认为是资源重复率 nodes.forEach(node => { if (!numMap.has(node.originalLabel)) { numMap.set(node.originalLabel, weight); diff --git a/ginka/dataset.py b/ginka/dataset.py index 478d779..265beb0 100644 --- a/ginka/dataset.py +++ b/ginka/dataset.py @@ -51,6 +51,25 @@ class GinkaDataset(Dataset): "target": target, } +class GinkaWGANDataset(Dataset): + def __init__(self, data_path: str, device): + self.data = load_data(data_path) # 自定义数据加载函数 + self.device = device + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + item = self.data[idx] + + target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W] + min_main = random.uniform(0.75, 0.9) + max_main = random.uniform(0.9, 1) + epsilon = random.uniform(0, 0.25) + target_smooth = random_smooth_onehot(target, min_main, max_main, epsilon).to(self.device) + + return target_smooth + class MinamoGANDataset(Dataset): def __init__(self, refer_data_path): self.refer = load_minamo_gan_data(load_data(refer_data_path)) @@ -67,6 +86,7 @@ class MinamoGANDataset(Dataset): return len(self.data) def __getitem__(self, idx): + # 假定 map2 是参考地图 item = self.data[idx] map1, map2, vis_sim, topo_sim, review = item diff --git a/ginka/model/loss.py b/ginka/model/loss.py index 38bdcf1..d50842d 100644 --- a/ginka/model/loss.py +++ b/ginka/model/loss.py @@ -3,8 +3,10 @@ from tqdm import tqdm import torch import torch.nn as nn import torch.nn.functional as F +from torch_geometric.data import Data from minamo.model.model import MinamoModel from shared.graph import batch_convert_soft_map_to_graph +from shared.constant import VISION_WEIGHT, TOPO_WEIGHT CLASS_NUM = 32 ILLEGAL_MAX_NUM = 12 @@ -260,10 +262,129 @@ class GinkaLoss(nn.Module): ) losses = [ - minamo_loss * self.weight[0], + minamo_loss * self.weight[0] * 4, class_loss * self.weight[1], entrance_loss * self.weight[2], count_loss * self.weight[3] ] return sum(losses) + +# 对图像数据进行插值 +def interpolate_data(real_data, fake_data, epsilon): + return epsilon * real_data + (1 - epsilon) * fake_data + +# 对节点特征进行插值,但保持边连接关系不变 +def interpolate_graph_features(real_graph, fake_graph, epsilon=0.5): + # 插值节点特征 + x_real, x_fake = real_graph.x, fake_graph.x + x_interp = epsilon * x_real + (1 - epsilon) * x_fake + + # 保持边连接关系和边特征不变 + edge_index_interp = real_graph.edge_index # 保持边连接关系 + edge_attr_interp = real_graph.edge_attr # 如果有边特征,保持不变 + + return Data(x=x_interp, edge_index=edge_index_interp, edge_attr=edge_attr_interp) + +def js_divergence(P, Q, epsilon=1e-10): + """ + 输入: + P, Q: [B, C, H, W], 已通过 Softmax 处理 + 输出: + JS 散度标量(全局平均) + """ + # 转换为 [B, H, W, C] 以便在最后一维计算概率分布 + P = P.permute(0, 2, 3, 1) # [B, H, W, C] + Q = Q.permute(0, 2, 3, 1) + + # 平均分布 M = (P + Q)/2 + M = 0.5 * (P + Q) + + # 计算 KL(P||M) 和 KL(Q||M) + kl_pm = F.kl_div(torch.log(M + epsilon), P, reduction='none', log_target=False).sum(dim=-1) # [B, H, W] + kl_qm = F.kl_div(torch.log(M + epsilon), Q, reduction='none', log_target=False).sum(dim=-1) # [B, H, W] + + # JS 散度 = 0.5*(KL(P||M) + KL(Q||M)) + js = 0.5 * (kl_pm + kl_qm) + + # 全局平均(可替换为其他聚合方式) + return js.mean() # 标量 + +class WGANGinkaLoss: + def __init__(self, lambda_gp=10, weight=[0.7, 0.2, 0.1], diversity_lamda=0): + self.lambda_gp = lambda_gp # 梯度惩罚系数 + self.weight = weight + self.diversity_lamda = diversity_lamda + + def compute_gradient_penalty(self, critic, real_data, fake_data): + # 进行插值 + batch_size = real_data.size(0) + epsilon_data = torch.randn(batch_size, 1, 1, 1, device=real_data.device) + interp_data = interpolate_data(real_data, fake_data, epsilon_data) + interp_graph = batch_convert_soft_map_to_graph(interp_data) + + # 对图像进行反向传播并计算梯度 + interp_data.requires_grad_() + interp_graph.x.requires_grad_() + + _, d_vis_score, d_topo_score = critic(interp_data, interp_graph) + + # 计算梯度 + grad_vis = torch.autograd.grad( + outputs=d_vis_score, inputs=interp_data, + grad_outputs=torch.ones_like(d_vis_score), + create_graph=True, retain_graph=True, only_inputs=True + )[0] + grad_topo = torch.autograd.grad( + outputs=d_topo_score, inputs=interp_graph.x, + grad_outputs=torch.ones_like(d_topo_score), + create_graph=True, retain_graph=True, only_inputs=True + )[0] + + # 计算梯度的 L2 范数 + grad_norm_vis = grad_vis.view(batch_size, -1).norm(2, dim=1) + grad_norm_topo = grad_topo.view(batch_size, -1).norm(2, dim=1) + # 计算梯度惩罚项 + gp_loss_vis = ((grad_norm_vis - 1.0) ** 2).mean() + gp_loss_topo = ((grad_norm_topo - 1.0) ** 2).mean() + gp_loss = gp_loss_vis * VISION_WEIGHT + gp_loss_topo * TOPO_WEIGHT + # print(grad_norm_topo.mean().item(), grad_norm_vis.mean().item()) + + return gp_loss + + def discriminator_loss( + self, critic, real_data: torch.Tensor, + real_graph: torch.Tensor, fake_data: torch.Tensor + ): + """ 判别器损失函数 """ + fake_graph = batch_convert_soft_map_to_graph(fake_data) + real_scores, _, _ = critic(real_data, real_graph) + fake_scores, _, _ = critic(fake_data, fake_graph) + + # Wasserstein 距离 + d_loss = fake_scores.mean() - real_scores.mean() + grad_loss = self.compute_gradient_penalty(critic, real_data, fake_data) + + return d_loss, d_loss + self.lambda_gp * grad_loss + + def generator_loss_one(self, critic, fake): + fake_graph = batch_convert_soft_map_to_graph(fake) + fake_scores, _, _ = critic(fake, fake_graph) + minamo_loss = -torch.mean(fake_scores) + class_loss = outer_border_constraint_loss(fake) + inner_constraint_loss(fake) + entrance_loss = entrance_constraint_loss(fake) + + losses = [ + minamo_loss * self.weight[0], + class_loss * self.weight[1], + entrance_loss * self.weight[2] + ] + + return sum(losses) + + def generator_loss(self, critic, fake1, fake2): + """ 生成器损失函数 """ + loss1 = self.generator_loss_one(critic, fake1) + loss2 = self.generator_loss_one(critic, fake2) + + return loss1 * 0.5 + loss2 * 0.5 - self.diversity_lamda * js_divergence(fake1, fake2) diff --git a/ginka/model/model.py b/ginka/model/model.py index 0c75ba4..515dac8 100644 --- a/ginka/model/model.py +++ b/ginka/model/model.py @@ -3,7 +3,7 @@ import torch.nn as nn import torch.nn.functional as F from .unet import GinkaUNet from .output import GinkaOutput -from .input import GinkaInput, GinkaFeatureInput +from .input import GinkaInput def print_memory(tag=""): print(f"{tag} | 当前显存: {torch.cuda.memory_allocated() / 1024**2:.2f} MB, 最大显存: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB") @@ -13,9 +13,7 @@ class GinkaModel(nn.Module): """Ginka Model 模型定义部分 """ super().__init__() - self.input = GinkaInput(feat_dim, 1, (32, 32)) - self.feat_enc = GinkaFeatureInput(feat_dim, 2, base_ch) - self.unet = GinkaUNet(1, base_ch, base_ch, feat_dim) + self.unet = GinkaUNet(base_ch, base_ch, feat_dim) self.output = GinkaOutput(base_ch, out_ch, (13, 13)) def forward(self, x): @@ -25,12 +23,9 @@ class GinkaModel(nn.Module): Returns: logits: 输出logits [BS, num_classes, H, W] """ - cond = x - feat = self.feat_enc(x) - x = self.input(x) - x = self.unet(x, feat, cond) + x = self.unet(x) x = self.output(x) - return x, F.softmax(x, dim=1) + return F.softmax(x, dim=1) # 检查显存占用 if __name__ == "__main__": @@ -48,8 +43,8 @@ if __name__ == "__main__": print(f"输入形状: feat={feat.shape}") print(f"输出形状: output={output.shape}, softmax={output_softmax.shape}") - print(f"Input parameters: {sum(p.numel() for p in model.input.parameters())}") - print(f"Feature Encoder parameters: {sum(p.numel() for p in model.feat_enc.parameters())}") + # print(f"Input parameters: {sum(p.numel() for p in model.input.parameters())}") + # print(f"Feature Encoder parameters: {sum(p.numel() for p in model.feat_enc.parameters())}") print(f"UNet parameters: {sum(p.numel() for p in model.unet.parameters())}") print(f"Output parameters: {sum(p.numel() for p in model.output.parameters())}") print(f"Total parameters: {sum(p.numel() for p in model.parameters())}") diff --git a/ginka/model/unet.py b/ginka/model/unet.py index e6de53b..2d52ff2 100644 --- a/ginka/model/unet.py +++ b/ginka/model/unet.py @@ -1,40 +1,51 @@ import torch import torch.nn as nn import torch.nn.functional as F +from shared.constant import FEAT_DIM + +class GinkaTransformerEncoder(nn.Module): + def __init__(self, in_dim, hidden_dim, out_dim, token_size, ff_dim, num_heads=8, num_layers=6): + super().__init__() + in_dim = in_dim // token_size + hidden_dim = hidden_dim // token_size + out_dim = out_dim // token_size + self.embedding = nn.Sequential( + nn.Linear(in_dim, hidden_dim), + nn.LayerNorm(hidden_dim) + ) + self.pos_embedding = nn.Parameter(torch.randn(1, token_size, hidden_dim)) + self.transformer = nn.TransformerEncoder( + nn.TransformerEncoderLayer(hidden_dim, num_heads, dim_feedforward=ff_dim, batch_first=True), + num_layers=num_layers + ) + self.fc = nn.Sequential( + nn.Linear(hidden_dim, out_dim), + nn.LayerNorm(out_dim) + ) + + def forward(self, x): + # 输入 [B, L, in_dim] + # 输出 [B, L, out_dim] + x = self.embedding(x) # [B, L, hidden_dim] + x = x + self.pos_embedding # [B, L, hidden_dim] + x = self.transformer(x) # [B, L, hidden_dim] + x = self.fc(x) # [B, L, out_dim] + return x class ConvBlock(nn.Module): def __init__(self, in_ch, out_ch): super().__init__() self.conv = nn.Sequential( - nn.Conv2d(in_ch, out_ch, 3, padding=1), - nn.BatchNorm2d(out_ch), - nn.ReLU(), - nn.Conv2d(out_ch, out_ch, 3, padding=1), - nn.BatchNorm2d(out_ch), - nn.ReLU(), + nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='replicate'), + nn.InstanceNorm2d(out_ch), + nn.ELU(), + nn.Conv2d(out_ch, out_ch, 3, padding=1, padding_mode='replicate'), + nn.InstanceNorm2d(out_ch), + nn.ELU(), ) def forward(self, x): return self.conv(x) - -class ConditionFusionBlock(nn.Module): - def __init__(self): - super().__init__() - self.alpha = nn.Parameter(torch.tensor(0.5)) # 可学习融合系数 - - def forward(self, x, cond_feat): - return x + self.alpha * cond_feat # 残差融合 - -class FusionConvBlock(nn.Module): - def __init__(self, in_ch, out_ch): - super().__init__() - self.conv = ConvBlock(in_ch, out_ch) - self.fusion = ConditionFusionBlock() - - def forward(self, x, feat): - x = self.conv(x) - x = self.fusion(x, feat) - return x class GinkaEncoder(nn.Module): """编码器(下采样)部分""" @@ -42,12 +53,10 @@ class GinkaEncoder(nn.Module): super().__init__() self.conv = ConvBlock(in_ch, out_ch) self.pool = nn.MaxPool2d(2) - self.fusion = ConditionFusionBlock() - def forward(self, x, feat): + def forward(self, x): x = self.conv(x) x = self.pool(x) - x = self.fusion(x, feat) return x class GinkaUpSample(nn.Module): @@ -55,8 +64,8 @@ class GinkaUpSample(nn.Module): super().__init__() self.conv = nn.Sequential( nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2), - nn.BatchNorm2d(out_ch), - nn.GELU(), + nn.InstanceNorm2d(out_ch), + nn.ELU(), ) def forward(self, x): @@ -68,46 +77,57 @@ class GinkaDecoder(nn.Module): super().__init__() self.upsample = GinkaUpSample(in_ch, in_ch // 2) self.conv = ConvBlock(in_ch, out_ch) - self.fusion = ConditionFusionBlock() - def forward(self, x, skip, feat): - dec = self.upsample(x) - x = torch.cat([dec, skip], dim=1) + def forward(self, x, feat): + x = self.upsample(x) + x = torch.cat([x, feat], dim=1) x = self.conv(x) - x = self.fusion(x, feat) return x class GinkaUNet(nn.Module): - def __init__(self, in_ch=1, base_ch=64, out_ch=32, feat_dim=1024): + def __init__(self, base_ch=64, out_ch=32, feat_dim=1024): """Ginka Model UNet 部分 """ super().__init__() - self.in_conv = FusionConvBlock(in_ch, base_ch) - self.down1 = GinkaEncoder(base_ch, base_ch*2) - self.down2 = GinkaEncoder(base_ch*2, base_ch*4) - self.down3 = GinkaEncoder(base_ch*4, base_ch*8) - - self.bottleneck = GinkaEncoder(base_ch*8, base_ch*16) - - self.up1 = GinkaDecoder(base_ch*16, base_ch*8) - self.up2 = GinkaDecoder(base_ch*8, base_ch*4) - self.up3 = GinkaDecoder(base_ch*4, base_ch*2) - self.up4 = GinkaDecoder(base_ch*2, base_ch) + self.input = GinkaTransformerEncoder( + in_dim=feat_dim, hidden_dim=feat_dim*2, out_dim=2*32*32, # 自动除以 token_size + token_size=4, ff_dim=feat_dim*2, num_layers=4 + ) + self.down1 = ConvBlock(2, base_ch) + self.down2 = GinkaEncoder(base_ch, base_ch*2) + self.down3 = GinkaEncoder(base_ch*2, base_ch*4) + self.down4 = GinkaEncoder(base_ch*4, base_ch*8) + self.bottleneck = GinkaTransformerEncoder( + in_dim=base_ch*8*4*4, hidden_dim=base_ch*8*4*4, out_dim=base_ch*8*4*4, + token_size=16, ff_dim=1024, num_layers=4 + ) + + self.up1 = GinkaDecoder(base_ch*8, base_ch*4) + self.up2 = GinkaDecoder(base_ch*4, base_ch*2) + self.up3 = GinkaDecoder(base_ch*2, base_ch) self.final = nn.Sequential( nn.Conv2d(base_ch, out_ch, 1), + nn.InstanceNorm2d(out_ch), + nn.ELU(), ) - def forward(self, x, feat, cond): - x1 = self.in_conv(x, feat[0]) - x2 = self.down1(x1, feat[1]) - x3 = self.down2(x2, feat[2]) - x4 = self.down3(x3, feat[3]) - x5 = self.bottleneck(x4, feat[4]) + def forward(self, x): + B, D = x.shape # [B, 1024] + x = x.view(B, 4, D // 4) # [B, 4, 256] + x = self.input(x) # [B, 4, 512] + x = x.view(B, 2, 32, 32) # [B, 2, 32, 32] + x1 = self.down1(x) # [B, 64, 32, 32] + x2 = self.down2(x1) # [B, 128, 16, 16] + x3 = self.down3(x2) # [B, 256, 8, 8] + x4 = self.down4(x3) # [B, 512, 4, 4] + x4 = x4.view(B, 512, 16).permute(0, 2, 1) # [B, 16, 512] + x4 = self.bottleneck(x4) # [B, 16, 512] + x4 = x4.permute(0, 2, 1).view(B, 512, 4, 4) # [B, 512, 4, 4] - x = self.up1(x5, x4, feat[3]) - x = self.up2(x, x3, feat[2]) - x = self.up3(x, x2, feat[1]) - x = self.up4(x, x1, feat[0]) + # 上采样 + x = self.up1(x4, x3) # [B, 256, 8, 8] + x = self.up2(x, x2) # [B, 128, 16, 16] + x = self.up3(x, x1) # [B, 64, 32, 32] - return self.final(x) + return self.final(x) # [B, 32, 32, 32] diff --git a/ginka/train_gan.py b/ginka/train_gan.py index ee9ba08..51a3af8 100644 --- a/ginka/train_gan.py +++ b/ginka/train_gan.py @@ -11,17 +11,19 @@ from tqdm import tqdm import cv2 import numpy as np from .model.model import GinkaModel -from .model.loss import GinkaLoss +from .model.loss import GinkaLoss, WGANGinkaLoss from .dataset import GinkaDataset, MinamoGANDataset from minamo.model.model import MinamoModel from minamo.model.loss import MinamoLoss from shared.image import matrix_to_image_cv BATCH_SIZE = 32 -EPOCHS_GINKA = 30 -EPOCHS_MINAMO = 5 +EPOCHS_GINKA = 5 +EPOCHS_MINAMO = 2 SOCKET_PATH = "./tmp/ginka_uds" LOSS_PATH = "result/gan/a-loss.txt" +REPLAY_PATH = "datasets/replay.bin" +VISION_ALPHA = 0 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") os.makedirs("result", exist_ok=True) @@ -32,6 +34,10 @@ os.makedirs("tmp", exist_ok=True) with open(LOSS_PATH, 'a', encoding='utf-8') as f: f.write(f"---------- {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ----------\n") +if not os.path.exists(REPLAY_PATH): + with open(REPLAY_PATH, 'wb') as f: + f.write(b'\x00\x00\x00\x00') + def parse_arguments(): parser = argparse.ArgumentParser(description="training codes") parser.add_argument("--resume", type=bool, default=False) @@ -142,14 +148,16 @@ def train(): minamo_dataloader_val = DataLoader(minamo_dataset_val, batch_size=BATCH_SIZE // 2, shuffle=True) # 设定优化器与调度器 - optimizer_ginka = optim.AdamW(ginka.parameters(), lr=1e-3) + optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9)) scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=10, T_mult=2, eta_min=1e-6) criterion_ginka = GinkaLoss(minamo) - optimizer_minamo = optim.AdamW(minamo.parameters(), lr=1e-4) - scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=5, T_mult=2, eta_min=1e-6) + optimizer_minamo = optim.Adam(minamo.parameters(), lr=2e-5, betas=(0.0, 0.9)) + scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=EPOCHS_MINAMO, T_mult=2, eta_min=1e-6) criterion_minamo = MinamoLoss() + criterion = WGANGinkaLoss() + # 用于生成图片 tile_dict = dict() for file in os.listdir('tiles'): @@ -168,27 +176,17 @@ def train(): ginka.load_state_dict(data["model_state"], strict=False) print("Train from loaded state.") - else: - # 从头开始训练的话,初始时先把 minamo 损失值权重改为 0 - criterion_ginka.weight[0] = 0.0 - print("Waiting for client connection...") conn, _ = server.accept() print("Client connected.") for cycle in tqdm(range(args.from_cycle, args.to_cycle), desc="Total Progress"): # -------------------- 训练生成器 - gen_list: np.ndarray = np.empty((0, 13, 13), np.int8) - prob_list: np.ndarray = np.empty((0, 32, 13, 13), np.float32) - for epoch in tqdm(range(EPOCHS_GINKA), desc="Training Ginka Model"): + for epoch in tqdm(range(EPOCHS_GINKA), desc="Training Ginka Model", leave=False): ginka.train() minamo.eval() total_loss = 0 - # 从头开始训练的,在第 10 个 epoch 将 minamo 损失值权重改回来 - if not args.resume and epoch == 10: - criterion_ginka.weight[0] = 0.5 - for batch in tqdm(ginka_dataloader, leave=False, desc="Epoch Progress"): # 数据迁移到设备 target, target_vision_feat, target_topo_feat, feat_vec = parse_ginka_batch(batch) @@ -233,6 +231,8 @@ def train(): }, f"result/ginka_checkpoint/{epoch + 1}.pth") # 使用训练集生成 minamo 训练数据,更准确 + gen_list: np.ndarray = np.empty((0, 13, 13), np.int8) + prob_list: np.ndarray = np.empty((0, 32, 13, 13), np.float32) with torch.no_grad(): for batch in ginka_dataloader: target, target_vision_feat, target_topo_feat, feat_vec = parse_ginka_batch(batch) @@ -263,7 +263,7 @@ def train(): buf.extend(gen_bytes) # Map tensor conn.sendall(buf) data = parse_minamo_data(conn, prob_list) - minamo_dataset.set_data(data) + vis_sim = 0 topo_sim = 0 for _, _, vis, topo, _ in data: @@ -276,6 +276,44 @@ def train(): with open(LOSS_PATH, 'a', encoding='utf-8') as f: f.write(f'Cycle {cycle} | Ginka Vision Similarity: {vis_sim:.12f} | Ginka Topo Similarity: {topo_sim:.12f} | Ginka Loss: {avg_val_loss:.12f}') + # 经验回放部分 + with open(REPLAY_PATH, 'r+b') as f: + # 读取文件开头获取总长度 + f.seek(0) + count = struct.unpack('>i', f.read(4))[0] # 取出整数 + if count > 0: + replay = np.random.choice(count, size=min(count, len(data) // 4), replace=False) + + replay_data = np.empty((len(replay), 32, 13, 13)) + for i, n in enumerate(replay): + f.seek(n * 32 * 13 * 13 + 4) + arr = np.frombuffer(f.read(32 * 13 * 13 * 4), dtype=np.float32).reshape(32, 13, 13) + replay_data[i] = arr + + map_data: np.ndarray = replay_data.argmax(axis=1) + buf = bytearray() + buf.extend(struct.pack('>h', len(replay))) # Tensor count + buf.extend(struct.pack('>b', H)) # Map height + buf.extend(struct.pack('>b', W)) # Map width + buf.extend(map_data.astype(np.int8).tobytes()) # Map tensor + conn.sendall(buf) + data.extend(parse_minamo_data(conn, replay_data)) + + # 把新的内容写入文件末尾 + to_write = np.random.choice(N, size=min(N, 100), replace=False) + write_data = bytearray() + for n in to_write: + write_data.extend(prob_list[n].tobytes()) + + f.seek(0, 2) # 定位到文件末尾 + f.write(write_data) + + f.seek(0) # 定位到文件开头 + f.write(struct.pack('>i', count + len(to_write))) + f.flush() # 确保数据被刷新到磁盘 + + minamo_dataset.set_data(data) + # -------------------- 训练判别器 for epoch in tqdm(range(EPOCHS_MINAMO), leave=False, desc="Training Minamo Model"): ginka.eval() @@ -283,21 +321,43 @@ def train(): total_loss = 0 for batch in tqdm(minamo_dataloader, leave=False, desc="Epoch Progress"): - map1, map2, vision_simi, topo_simi, graph1, graph2 = parse_minamo_batch(batch) + map1, map2, vis_sim, topo_sim, graph1, graph2 = parse_minamo_batch(batch) + batch_size = map1.shape[0] - if map1.shape[0] == 1: + if batch_size == 1: continue # 前向传播 optimizer_minamo.zero_grad() - vision_feat1, topo_feat1 = minamo(map1, graph1) - vision_feat2, topo_feat2 = minamo(map2, graph2) + vis_feat_real, topo_feat_real = minamo(map1, graph1) + vis_feat_ref, topo_feat_ref = minamo(map2, graph2) - vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, dim=1).unsqueeze(-1) - topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, dim=1).unsqueeze(-1) + # 生成假数据 + with torch.no_grad(): + fake_feat = torch.randn((batch_size, 1024), device=device) + fake_data = ginka(fake_feat) + + # 创建插值样本 + alpha = torch.rand((batch_size, 1, 1, 1), device=device) + interpolates = (alpha * map2 + (1 - alpha) * fake_data).requires_grad_(True) + + vis_feat_fake, topo_feat_fake = minamo(fake_data) + vis_feat_interp, topo_feat_interp = minamo(interpolates) + + vis_pred_real = F.cosine_similarity(vis_feat_real, vis_feat_ref, dim=1).unsqueeze(-1) + topo_pred_real = F.cosine_similarity(topo_feat_real, topo_feat_ref, dim=1).unsqueeze(-1) + vis_pred_fake = F.cosine_similarity(vis_feat_fake, vis_feat_ref, dim=1).unsqueeze(-1) + topo_pred_fake = F.cosine_similarity(topo_feat_fake, topo_feat_ref, dim=1).unsqueeze(-1) + vis_pred_interp = F.cosine_similarity(vis_feat_interp, vis_feat_ref, dim=1).unsqueeze(-1) + topo_pred_interp = F.cosine_similarity(topo_feat_interp, topo_feat_ref, dim=1).unsqueeze(-1) + + # 计算相似度 + score_real = F.l1_loss(vis_pred_real, vis_sim) * VISION_ALPHA + F.l1_loss(topo_pred_real, topo_sim) * (1 - VISION_ALPHA) + score_fake = vis_pred_fake * VISION_ALPHA + topo_pred_fake * (1 - VISION_ALPHA) + score_interp = vis_pred_interp * VISION_ALPHA + topo_pred_interp * (1 - VISION_ALPHA) # 计算损失 - loss = criterion_minamo(vision_pred, topo_pred, vision_simi, topo_simi) + loss = criterion.discriminator_loss(score_real, score_fake, score_interp) # 反向传播 loss.backward() @@ -310,21 +370,21 @@ def train(): scheduler_minamo.step(epoch + 1) # 每十轮推理一次验证集 - if (epoch + 1) % 5 == 0: + if epoch + 1 == EPOCHS_MINAMO: minamo.eval() val_loss = 0 with torch.no_grad(): for val_batch in tqdm(minamo_dataloader_val, leave=False, desc="Validating Minamo Model"): map1_val, map2_val, vision_simi_val, topo_simi_val, graph1, graph2 = parse_minamo_batch(val_batch) - vision_feat1, topo_feat1 = minamo(map1_val, graph1) - vision_feat2, topo_feat2 = minamo(map2_val, graph2) + vis_feat_real, topo_feat_real = minamo(map1_val, graph1) + vis_feat_ref, topo_feat_ref = minamo(map2_val, graph2) - vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, dim=1).unsqueeze(-1) - topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, dim=1).unsqueeze(-1) + vis_pred_real = F.cosine_similarity(vis_feat_real, vis_feat_ref, dim=1).unsqueeze(-1) + topo_pred_real = F.cosine_similarity(topo_feat_real, topo_feat_ref, dim=1).unsqueeze(-1) # 计算损失 - loss_val = criterion_minamo(vision_pred, topo_pred, vision_simi_val, topo_simi_val) + loss_val = criterion_minamo(vis_pred_real, topo_pred_real, vision_simi_val, topo_simi_val) val_loss += loss_val.item() avg_val_loss = val_loss / len(minamo_dataloader_val) diff --git a/ginka/train_wgan.py b/ginka/train_wgan.py new file mode 100644 index 0000000..ac70e9c --- /dev/null +++ b/ginka/train_wgan.py @@ -0,0 +1,164 @@ +import argparse +import os +from datetime import datetime +import torch +import torch.optim as optim +import cv2 +from torch_geometric.loader import DataLoader +from tqdm import tqdm +from .model.model import GinkaModel +from .dataset import GinkaWGANDataset +from .model.loss import WGANGinkaLoss +from minamo.model.model import MinamoScoreModule +from shared.graph import batch_convert_soft_map_to_graph +from shared.image import matrix_to_image_cv +from shared.constant import VISION_WEIGHT, TOPO_WEIGHT + +BATCH_SIZE = 32 + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +os.makedirs("result", exist_ok=True) +os.makedirs("result/wgan", exist_ok=True) + +def parse_arguments(): + parser = argparse.ArgumentParser(description="training codes") + parser.add_argument("--resume", type=bool, default=False) + parser.add_argument("--state_ginka", type=str, default="result/ginka.pth") + parser.add_argument("--state_minamo", type=str, default="result/minamo.pth") + parser.add_argument("--train", type=str, default="ginka-dataset.json") + parser.add_argument("--epochs", type=int, default=100) + args = parser.parse_args() + return args + +def clip_weights(model, clip_value=0.01): + for param in model.parameters(): + param.data = torch.clamp(param.data, -clip_value, clip_value) + +def train(): + print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.") + + c_steps = 1 + g_steps = 3 + + args = parse_arguments() + + ginka = GinkaModel() + minamo = MinamoScoreModule() + ginka.to(device) + minamo.to(device) + + dataset = GinkaWGANDataset(args.train, device) + dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True) + + optimizer_ginka = optim.RMSprop(ginka.parameters(), lr=2e-4) + optimizer_minamo = optim.RMSprop(minamo.parameters(), lr=1e-5) + + criterion = WGANGinkaLoss() + + # 用于生成图片 + tile_dict = dict() + for file in os.listdir('tiles'): + name = os.path.splitext(file)[0] + tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED) + + if args.resume: + data = torch.load(args.state_ginka, map_location=device) + ginka.load_state_dict(data["model_state"], strict=False) + data = torch.load(args.state_minamo, map_location=device) + minamo.load_state_dict(data["model_state"], strict=False) + print("Train from loaded state.") + + for epoch in tqdm(range(args.epochs), desc="GAN Training"): + loss_total_minamo = torch.Tensor([0]).to(device) + loss_total_ginka = torch.Tensor([0]).to(device) + dis_total = torch.Tensor([0]).to(device) + + for real_data in tqdm(dataloader, leave=False, desc="Epoch Progress"): + batch_size = real_data.size(0) + real_data = real_data.to(device) + real_graph = batch_convert_soft_map_to_graph(real_data) + + optimizer_ginka.zero_grad() + + # ---------- 训练判别器 + for _ in range(c_steps): + # 生成假样本 + optimizer_minamo.zero_grad() + z = torch.randn(batch_size, 1024, device=device) + fake_data = ginka(z) + fake_data = fake_data.detach() + + # 计算判别器输出 + # 反向传播 + dis, loss_d = criterion.discriminator_loss(minamo, real_data, real_graph, fake_data) + loss_d.backward() + # torch.nn.utils.clip_grad_norm_(minamo_vis.parameters(), max_norm=1.0) + optimizer_minamo.step() + + loss_total_minamo += loss_d + dis_total += dis + + # ---------- 训练生成器 + + for _ in range(g_steps): + z1 = torch.randn(batch_size, 1024, device=device) + z2 = torch.randn(batch_size, 1024, device=device) + fake_softmax1, fakse_softmax2 = ginka(z1), ginka(z2) + + loss_g = criterion.generator_loss(minamo, fake_softmax1, fakse_softmax2) + loss_g.backward() + optimizer_ginka.step() + + loss_total_ginka += loss_g + # tqdm.write(f"{dis.item():.12f}, {loss_d.item():.12f}, {loss_g.item():.12f}") + + avg_loss_ginka = loss_total_ginka.item() / len(dataloader) / g_steps + avg_loss_minamo = loss_total_minamo.item() / len(dataloader) / c_steps + avg_dis = dis_total.item() / len(dataloader) / c_steps + tqdm.write( + f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | Wasserstein Loss: {avg_dis:.8f} | Loss Ginka: {avg_loss_ginka:.8f} | Loss Minamo: {avg_loss_minamo:.8f}" + ) + + if avg_dis < -9: + g_steps = 7 + elif avg_dis < -6: + g_steps = 5 + elif avg_dis < -3: + g_steps = 3 + else: + g_steps = 1 + + # 每五轮输出一次图片,并保存检查点 + if (epoch + 1) % 5 == 0: + # 输出 20 张图片,每批次 4 张,一共五批 + idx = 0 + with torch.no_grad(): + for _ in range(5): + z = torch.randn(4, 1024, device=device) + output = ginka(z) + + map_matrix = torch.argmax(output, dim=1).cpu().numpy() + for matrix in map_matrix: + image = matrix_to_image_cv(matrix, tile_dict) + cv2.imwrite(f"result/ginka_img/{idx}.png", image) + idx += 1 + + # 保存检查点 + torch.save({ + "model_state": ginka.state_dict() + }, f"result/wgan/ginka-{epoch + 1}.pth") + torch.save({ + "model_state": minamo.state_dict() + }, f"result/wgan/minamo-{epoch + 1}.pth") + + print("Train ended.") + torch.save({ + "model_state": ginka.state_dict() + }, f"result/ginka.pth") + torch.save({ + "model_state": minamo.state_dict() + }, f"result/minamo.pth") + +if __name__ == "__main__": + torch.set_num_threads(4) + train() diff --git a/minamo/model/model.py b/minamo/model/model.py index ffb051e..f392953 100644 --- a/minamo/model/model.py +++ b/minamo/model/model.py @@ -1,7 +1,9 @@ +import torch import torch.nn as nn import torch.nn.functional as F from .vision import MinamoVisionModel from .topo import MinamoTopoModel +from shared.constant import VISION_WEIGHT, TOPO_WEIGHT class MinamoModel(nn.Module): def __init__(self, tile_types=32): @@ -16,3 +18,62 @@ class MinamoModel(nn.Module): topo_feat = self.topo_model(graph) return vision_feat, topo_feat + +class MinamoVisionScore(nn.Module): + def __init__(self, tile_types=32): + super().__init__() + # 视觉相似度部分 + self.vision_model = MinamoVisionModel(tile_types) + # 输出层 + self.vision_fc = nn.Sequential( + nn.Linear(512, 2048), + nn.LeakyReLU(0.2), + nn.Linear(2048, 1) + ) + + def forward(self, map): + vision_feat = self.vision_model(map) + vis_score = self.vision_fc(vision_feat) + return vis_score + +class MinamoTopoScore(nn.Module): + def __init__(self, tile_types=32): + super().__init__() + # 拓扑相似度部分 + self.topo_model = MinamoTopoModel(tile_types) + # 输出层 + self.topo_fc = nn.Sequential( + nn.Linear(512, 2048), + nn.LeakyReLU(0.2), + nn.Linear(2048, 1) + ) + + def forward(self, graph): + topo_feat = self.topo_model(graph) + topo_score = self.topo_fc(topo_feat) + return topo_score + +class MinamoScoreModule(nn.Module): + def __init__(self, tile_types=32): + super().__init__() + self.topo_model = MinamoTopoModel(tile_types) + self.vision_model = MinamoVisionModel(tile_types) + # 输出层 + self.topo_fc = nn.Sequential( + nn.Linear(512, 2048), + nn.LeakyReLU(0.2), + nn.Linear(2048, 1) + ) + self.vision_fc = nn.Sequential( + nn.Linear(512, 2048), + nn.LeakyReLU(0.2), + nn.Linear(2048, 1) + ) + + def forward(self, map, graph): + topo_feat = self.topo_model(graph) + topo_score = self.topo_fc(topo_feat) + vision_feat = self.vision_model(map) + vision_score = self.vision_fc(vision_feat) + score = VISION_WEIGHT * vision_score + TOPO_WEIGHT * topo_score + return score, vision_score, topo_score diff --git a/minamo/model/topo.py b/minamo/model/topo.py index 9e34ab1..bbf9d2d 100644 --- a/minamo/model/topo.py +++ b/minamo/model/topo.py @@ -1,56 +1,50 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch_geometric.nn import GATConv, global_max_pool +from torch.nn.utils import spectral_norm +from torch_geometric.nn import GATConv, global_max_pool, GCNConv, global_mean_pool from torch_geometric.data import Data class MinamoTopoModel(nn.Module): def __init__( - self, tile_types=32, emb_dim=128, hidden_dim=128, out_dim=512, mlp_dim=512 + self, tile_types=32, emb_dim=128, hidden_dim=256, out_dim=512, feat_dim=512 ): super().__init__() # 传入 softmax 概率值,直接映射 - self.input_proj = nn.Linear(tile_types, emb_dim) + self.input_proj = nn.Sequential( + nn.Linear(tile_types, emb_dim), + nn.LeakyReLU(0.2) + ) # 图卷积层 - self.conv1 = GATConv(emb_dim, hidden_dim*2, heads=8, dropout=0.2) - self.conv2 = GATConv(hidden_dim*16, hidden_dim*2, heads=8) - self.conv3 = GATConv(hidden_dim*16, hidden_dim*2, heads=8) - self.conv4 = GATConv(hidden_dim*16, out_dim, heads=1) + self.conv1 = GATConv(emb_dim, hidden_dim, heads=8) + self.conv2 = GATConv(hidden_dim*8, hidden_dim, heads=8) + self.conv3 = GATConv(hidden_dim*8, out_dim, heads=1) - # 正则化 - self.norm1 = nn.LayerNorm(hidden_dim*16) - self.norm2 = nn.LayerNorm(hidden_dim*16) - self.norm3 = nn.LayerNorm(hidden_dim*16) - self.norm4 = nn.LayerNorm(out_dim) + # self.norm1 = nn.LayerNorm(hidden_dim*8) + # self.norm2 = nn.LayerNorm(hidden_dim*8) + # self.norm3 = nn.LayerNorm(out_dim) - self.drop = nn.Dropout(0.3) - - # 增强MLP self.fc = nn.Sequential( - nn.Linear(out_dim, mlp_dim), + nn.Linear(out_dim, feat_dim), + nn.LeakyReLU(0.2) ) def forward(self, graph: Data): x = self.input_proj(graph.x) x = self.conv1(x, graph.edge_index) - x = F.relu(self.norm1(x)) + x = F.leaky_relu(x, 0.2) x = self.conv2(x, graph.edge_index) - x = F.relu(self.norm2(x)) + x = F.leaky_relu(x, 0.2) x = self.conv3(x, graph.edge_index) - x = F.relu(self.norm3(x)) - - x = self.conv4(x, graph.edge_index) - x = F.relu(self.norm4(x)) + x = F.leaky_relu(x, 0.2) # 池化 - x = self.drop(x) - x = global_max_pool(x, graph.batch) + x = global_mean_pool(x, graph.batch) topo_vec = self.fc(x) - # 归一化 - return F.normalize(topo_vec, p=2, dim=-1) + return topo_vec \ No newline at end of file diff --git a/minamo/model/vision.py b/minamo/model/vision.py index 415b272..0ba7a25 100644 --- a/minamo/model/vision.py +++ b/minamo/model/vision.py @@ -1,14 +1,28 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torchvision.models import resnet18 +from torch.nn.utils import spectral_norm class MinamoVisionModel(nn.Module): - def __init__(self, tile_types=32, out_dim=512): + def __init__(self, in_ch=32, out_dim=512): super().__init__() - self.resnet = resnet18(num_classes=out_dim) - self.resnet.conv1 = nn.Conv2d(tile_types, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.conv = nn.Sequential( + spectral_norm(nn.Conv2d(in_ch, in_ch*2, 3, stride=2)), # 6*6 + nn.LeakyReLU(0.2), + + spectral_norm(nn.Conv2d(in_ch*2, in_ch*4, 3)), #4*4 + nn.LeakyReLU(0.2), + + spectral_norm(nn.Conv2d(in_ch*4, in_ch*8, 3)), # 2*2 + nn.LeakyReLU(0.2), + + nn.Flatten() + ) + self.fc = nn.Sequential( + spectral_norm(nn.Linear(in_ch*8*2*2, out_dim)) + ) def forward(self, x): - vision_vec = self.resnet(x) - return F.normalize(vision_vec, p=2, dim=-1) # 归一化 + x = self.conv(x) + x = self.fc(x) + return x diff --git a/shared/constant.py b/shared/constant.py new file mode 100644 index 0000000..846a53c --- /dev/null +++ b/shared/constant.py @@ -0,0 +1,6 @@ +VIS_DIM = 512 +TOPO_DIM = 512 +FEAT_DIM = 1024 + +VISION_WEIGHT = 0 +TOPO_WEIGHT = 1 diff --git a/shared/graph.py b/shared/graph.py index 4ba5adc..e82ac7c 100644 --- a/shared/graph.py +++ b/shared/graph.py @@ -1,5 +1,6 @@ import torch from torch_geometric.data import Data, Batch +from torch_geometric.utils import add_self_loops def differentiable_convert_to_data(map_probs: torch.Tensor) -> Data: """ @@ -44,6 +45,8 @@ def differentiable_convert_to_data(map_probs: torch.Tensor) -> Data: src_feat = map_probs[:, edge_src // W, edge_src % W].T # [E, C] dst_feat = map_probs[:, edge_dst // W, edge_dst % W].T # [E, C] edge_attr = (src_feat + dst_feat) / 2 * edge_mask # [E, C] + + edge_index, edge_attr = add_self_loops(edge_index, edge_attr) return Data( x=node_features,