diff --git a/data/src/minamo.ts b/data/src/minamo.ts index dd6578e..f15218c 100644 --- a/data/src/minamo.ts +++ b/data/src/minamo.ts @@ -206,6 +206,7 @@ function generateSimilarData(id: string, map: number[][]) { const id2 = `${id}.S${i}`; const sid = `${id}:${id2}`; const simi = compareMap(id, id2, map, clone); + res.push([ sid, { diff --git a/data/src/topology/similarity.ts b/data/src/topology/similarity.ts index aa2799d..9275d3c 100644 --- a/data/src/topology/similarity.ts +++ b/data/src/topology/similarity.ts @@ -62,7 +62,7 @@ function weisfeilerLehmanIteration( const compositeLabel = `${node.currentLabel}|${neighborLabels.join( ',' - )}`.slice(0, 4096); + )}`.slice(0, 8192); newLabels.push(compositeLabel); }); @@ -157,7 +157,7 @@ export function overallSimilarity( const min = Math.min(ga.graph.size, gb.graph.size); const iterations = Math.ceil(Math.max(1, Math.log(min))); const similarity = wlKernel(ga, gb, iterations); - if (similarity > maxSimilarity) { + if (similarity > maxSimilarity && !isNaN(similarity)) { maxSimilarity = similarity; maxGraph = gb; } @@ -171,5 +171,9 @@ export function overallSimilarity( const reduction = 1 / (1 + Math.abs(a.unreachable.size - b.unreachable.size)); // 取根号使结果更接近线性 - return Math.sqrt(totalSimilarity / graphsA.length) * reduction; + if (graphsA.length === 0) { + return 0; + } else { + return Math.sqrt(totalSimilarity / graphsA.length) * reduction; + } } diff --git a/ginka/model/input.py b/ginka/model/input.py new file mode 100644 index 0000000..5348a2b --- /dev/null +++ b/ginka/model/input.py @@ -0,0 +1,41 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class ResidualUpsampleBlock(nn.Module): + def __init__(self, in_ch, out_ch): + super().__init__() + self.conv = nn.Sequential( + nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), + nn.Conv2d(in_ch, out_ch, 3, padding=1), + nn.InstanceNorm2d(out_ch), + nn.GELU(), + nn.Conv2d(out_ch, out_ch, 3, padding=1), + nn.InstanceNorm2d(out_ch), + nn.GELU() + ) + + def forward(self, x): + return self.conv(x) + +class GinkaInput(nn.Module): + def __init__(self, feat_dim=1024, out_ch=64): + super().__init__() + fc_dim = out_ch * 8 * 4 * 4 + self.out_ch = out_ch + self.fc = nn.Sequential( + nn.Linear(feat_dim, fc_dim), + nn.BatchNorm1d(fc_dim), + nn.ReLU() + ) + self.upsample = nn.Sequential( + ResidualUpsampleBlock(out_ch*8, out_ch*8), + ResidualUpsampleBlock(out_ch*8, out_ch*4), + ResidualUpsampleBlock(out_ch*4, out_ch) + ) + + def forward(self, x): + x = self.fc(x) + x = x.view(-1, self.out_ch*8, 4, 4) + x = self.upsample(x) + return x diff --git a/ginka/model/model.py b/ginka/model/model.py index 334cd71..4cdb830 100644 --- a/ginka/model/model.py +++ b/ginka/model/model.py @@ -2,49 +2,31 @@ import torch import torch.nn as nn import torch.nn.functional as F from .unet import GinkaUNet -from .sample import MapDownSample +from .input import GinkaInput +from .output import GinkaOutput class GinkaModel(nn.Module): def __init__(self, feat_dim=1024, base_ch=64, num_classes=32): """Ginka Model 模型定义部分 """ super().__init__() - self.base_ch = base_ch - fc_dim = base_ch * 8 * 4 * 4 - self.fc = nn.Sequential( - nn.Linear(feat_dim, fc_dim), - nn.BatchNorm1d(fc_dim), - nn.ReLU() - ) - self.upsample = nn.Sequential( - nn.Conv2d(base_ch*8, base_ch*16, kernel_size=3, stride=1, padding=1), - nn.PixelShuffle(2), - nn.InstanceNorm2d(base_ch*4), - nn.ReLU(), - nn.Conv2d(base_ch*4, base_ch*8, kernel_size=3, stride=1, padding=1), - nn.PixelShuffle(2), - nn.InstanceNorm2d(base_ch*2), - nn.ReLU(), - nn.Conv2d(base_ch*2, base_ch*4, kernel_size=3, stride=1, padding=1), - nn.PixelShuffle(2), - nn.InstanceNorm2d(base_ch), - nn.ReLU(), - ) + self.input = GinkaInput(feat_dim, base_ch) self.unet = GinkaUNet(base_ch, num_classes) - self.down_sample = MapDownSample(num_classes, num_classes) - self.pool = nn.AdaptiveMaxPool2d((13, 13)) + self.output = GinkaOutput(num_classes, (13, 13)) + print(f"Input parameters: {sum(p.numel() for p in self.input.parameters())}") + print(f"UNet parameters: {sum(p.numel() for p in self.unet.parameters())}") + print(f"Output parameters: {sum(p.numel() for p in self.output.parameters())}") + print(f"Total parameters: {sum(p.numel() for p in self.parameters())}") - def forward(self, feat): + def forward(self, x): """ Args: feat: 参考地图的特征向量 Returns: logits: 输出logits [BS, num_classes, H, W] """ - x = self.fc(feat) - x = x.view(-1, self.base_ch*8, 4, 4) - x = self.upsample(x) + x = self.input(x) x = self.unet(x) - x = F.interpolate(x, (13, 13), mode='bilinear') + x = self.output(x) return x, F.softmax(x, dim=1) \ No newline at end of file diff --git a/ginka/model/output.py b/ginka/model/output.py new file mode 100644 index 0000000..89989ab --- /dev/null +++ b/ginka/model/output.py @@ -0,0 +1,10 @@ +import torch +import torch.nn as nn + +class GinkaOutput(nn.Module): + def __init__(self, num_classes=32, out_size=(13, 13)): + super().__init__() + self.pool = nn.AdaptiveAvgPool2d(out_size) + + def forward(self, x): + return self.pool(x) diff --git a/ginka/model/sample.py b/ginka/model/sample.py deleted file mode 100644 index 1dd5193..0000000 --- a/ginka/model/sample.py +++ /dev/null @@ -1,15 +0,0 @@ -import torch -import torch.nn as nn - -class MapDownSample(nn.Module): - def __init__(self, in_ch=32, out_ch=32): - super().__init__() - self.down = nn.Sequential( - nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1), - nn.ReLU(), - nn.Conv2d(in_ch, out_ch, 4, stride=1, padding=0) - ) - - def forward(self, x): - x = self.down(x) - return x diff --git a/ginka/model/unet.py b/ginka/model/unet.py index 5a40ee9..9303e51 100644 --- a/ginka/model/unet.py +++ b/ginka/model/unet.py @@ -21,11 +21,11 @@ class GinkaEncoder(nn.Module): elif block == 'SEBlock': self.conv.append(SEBlock(out_channels)) self.conv.append(nn.GELU()) - self.pool = nn.MaxPool2d(2) + self.down = nn.Conv2d(out_channels, out_channels, 3, stride=2, padding=1) def forward(self, x): x_res = self.conv(x) - x_down = self.pool(x_res) + x_down = self.down(x_res) return x_down, x_res class GinkaDecoder(nn.Module): @@ -53,23 +53,24 @@ class GinkaDecoder(nn.Module): return x class GinkaBottleneck(nn.Module): - def __init__(self, in_channels, out_channels): + def __init__(self, in_channels, out_channels, attention=False): super().__init__() self.conv = nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), nn.InstanceNorm2d(out_channels), nn.GELU(), - SEBlock(out_channels), nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), nn.InstanceNorm2d(out_channels), - nn.GELU(), ) + if attention: + self.conv.append(SEBlock(out_channels)) + self.conv.append(nn.GELU()) def forward(self, x): return self.conv(x) class GinkaUNet(nn.Module): - def __init__(self, in_ch=32, out_ch=32): + def __init__(self, in_ch=64, out_ch=32): """Ginka Model UNet 部分 """ super().__init__() @@ -78,7 +79,7 @@ class GinkaUNet(nn.Module): self.down3 = GinkaEncoder(in_ch*4, in_ch*8, attention=True, block='SEBlock') self.down4 = GinkaEncoder(in_ch*8, in_ch*16, attention=True, block='SEBlock') - self.bottleneck = GinkaBottleneck(in_ch*16, in_ch*16) + self.bottleneck = GinkaBottleneck(in_ch*16, in_ch*16, attention=True) self.up1 = GinkaDecoder(in_ch*16, in_ch*8, attention=True, block='SEBlock') self.up2 = GinkaDecoder(in_ch*8, in_ch*4, attention=True, block='SEBlock') @@ -97,9 +98,9 @@ class GinkaUNet(nn.Module): x = self.bottleneck(x_down4) - x = self.up1(x, skip4) # 用 down2 的 skip - x = self.up2(x, skip3) # 用 down2 的 skip - x = self.up3(x, skip2) # 用 down1 的 skip - x = self.up4(x, skip1) # 用 down1 的 skip + x = self.up1(x, skip4) + x = self.up2(x, skip3) + x = self.up3(x, skip2) + x = self.up4(x, skip1) return self.final(x) diff --git a/minamo/train.py b/minamo/train.py index 24f2dcd..ef81db6 100644 --- a/minamo/train.py +++ b/minamo/train.py @@ -91,8 +91,6 @@ def train(): graph1 = graph1.to(device) graph2 = graph2.to(device) - # print(map1.shape, map2.shape) - # 前向传播 optimizer.zero_grad() vision_feat1, topo_feat1 = model(map1, graph1) diff --git a/minamo/validate.py b/minamo/validate.py index 8eae08f..d9635c8 100644 --- a/minamo/validate.py +++ b/minamo/validate.py @@ -44,8 +44,6 @@ def validate(): vision_feat1, topo_feat1 = model(map1_val, graph1) vision_feat2, topo_feat2 = model(map2_val, graph2) - print(vision_feat1.isnan().any().item(), topo_feat1.isnan().any().item(), vision_feat2.isnan().any().item(), topo_feat2.isnan().any().item()) - vision_pred_val = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1) topo_pred_val = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1) loss_val = criterion( diff --git a/shared/graph.py b/shared/graph.py index bfefa8c..c29eb9c 100644 --- a/shared/graph.py +++ b/shared/graph.py @@ -1,7 +1,6 @@ import torch from torch_geometric.data import Data, Batch - def differentiable_convert_to_data(map_probs: torch.Tensor) -> Data: """ 可导的图结构转换(返回 PyG Data 对象)