feat: 改进模型与相似度计算

This commit is contained in:
unanmed 2025-03-26 21:06:28 +08:00
parent def3cc5c3f
commit a5510ef211
10 changed files with 82 additions and 63 deletions

View File

@ -206,6 +206,7 @@ function generateSimilarData(id: string, map: number[][]) {
const id2 = `${id}.S${i}`;
const sid = `${id}:${id2}`;
const simi = compareMap(id, id2, map, clone);
res.push([
sid,
{

View File

@ -62,7 +62,7 @@ function weisfeilerLehmanIteration(
const compositeLabel = `${node.currentLabel}|${neighborLabels.join(
','
)}`.slice(0, 4096);
)}`.slice(0, 8192);
newLabels.push(compositeLabel);
});
@ -157,7 +157,7 @@ export function overallSimilarity(
const min = Math.min(ga.graph.size, gb.graph.size);
const iterations = Math.ceil(Math.max(1, Math.log(min)));
const similarity = wlKernel(ga, gb, iterations);
if (similarity > maxSimilarity) {
if (similarity > maxSimilarity && !isNaN(similarity)) {
maxSimilarity = similarity;
maxGraph = gb;
}
@ -171,5 +171,9 @@ export function overallSimilarity(
const reduction =
1 / (1 + Math.abs(a.unreachable.size - b.unreachable.size));
// 取根号使结果更接近线性
return Math.sqrt(totalSimilarity / graphsA.length) * reduction;
if (graphsA.length === 0) {
return 0;
} else {
return Math.sqrt(totalSimilarity / graphsA.length) * reduction;
}
}

41
ginka/model/input.py Normal file
View File

@ -0,0 +1,41 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class ResidualUpsampleBlock(nn.Module):
    """Upsample-by-2 convolutional refinement block.

    Doubles the spatial resolution with bilinear interpolation, then
    refines the result with two 3x3 Conv -> InstanceNorm -> GELU stages.

    NOTE(review): despite the name there is no residual/skip connection
    visible here -- the input is transformed purely sequentially.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Layer order is kept exactly as before so that state_dict keys
        # (conv.0 ... conv.6) stay checkpoint-compatible.
        layers = [
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.InstanceNorm2d(out_ch),
            nn.GELU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.InstanceNorm2d(out_ch),
            nn.GELU(),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Return the refined feature map; H and W are doubled."""
        out = self.conv(x)
        return out
class GinkaInput(nn.Module):
    """Project a flat feature vector into a spatial feature map.

    The vector is mapped by a fully-connected stack to a 4x4 seed tensor
    with ``out_ch * 8`` channels, then upsampled x8 through three
    ``ResidualUpsampleBlock`` stages (4 -> 8 -> 16 -> 32), narrowing the
    channel count down to ``out_ch``.
    """

    def __init__(self, feat_dim=1024, out_ch=64):
        super().__init__()
        # Seed tensor: out_ch*8 channels on a 4x4 grid.
        fc_dim = out_ch * 8 * 4 * 4
        self.out_ch = out_ch
        self.fc = nn.Sequential(
            nn.Linear(feat_dim, fc_dim),
            nn.BatchNorm1d(fc_dim),
            nn.ReLU()
        )
        self.upsample = nn.Sequential(
            ResidualUpsampleBlock(out_ch * 8, out_ch * 8),
            ResidualUpsampleBlock(out_ch * 8, out_ch * 4),
            ResidualUpsampleBlock(out_ch * 4, out_ch)
        )

    def forward(self, x):
        """Map a (B, feat_dim) input to a (B, out_ch, 32, 32) feature map."""
        seed = self.fc(x)
        seed = seed.view(-1, self.out_ch * 8, 4, 4)
        return self.upsample(seed)

View File

@ -2,49 +2,31 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from .unet import GinkaUNet
from .sample import MapDownSample
from .input import GinkaInput
from .output import GinkaOutput
class GinkaModel(nn.Module):
def __init__(self, feat_dim=1024, base_ch=64, num_classes=32):
"""Ginka Model 模型定义部分
"""
super().__init__()
self.base_ch = base_ch
fc_dim = base_ch * 8 * 4 * 4
self.fc = nn.Sequential(
nn.Linear(feat_dim, fc_dim),
nn.BatchNorm1d(fc_dim),
nn.ReLU()
)
self.upsample = nn.Sequential(
nn.Conv2d(base_ch*8, base_ch*16, kernel_size=3, stride=1, padding=1),
nn.PixelShuffle(2),
nn.InstanceNorm2d(base_ch*4),
nn.ReLU(),
nn.Conv2d(base_ch*4, base_ch*8, kernel_size=3, stride=1, padding=1),
nn.PixelShuffle(2),
nn.InstanceNorm2d(base_ch*2),
nn.ReLU(),
nn.Conv2d(base_ch*2, base_ch*4, kernel_size=3, stride=1, padding=1),
nn.PixelShuffle(2),
nn.InstanceNorm2d(base_ch),
nn.ReLU(),
)
self.input = GinkaInput(feat_dim, base_ch)
self.unet = GinkaUNet(base_ch, num_classes)
self.down_sample = MapDownSample(num_classes, num_classes)
self.pool = nn.AdaptiveMaxPool2d((13, 13))
self.output = GinkaOutput(num_classes, (13, 13))
print(f"Input parameters: {sum(p.numel() for p in self.input.parameters())}")
print(f"UNet parameters: {sum(p.numel() for p in self.unet.parameters())}")
print(f"Output parameters: {sum(p.numel() for p in self.output.parameters())}")
print(f"Total parameters: {sum(p.numel() for p in self.parameters())}")
def forward(self, feat):
def forward(self, x):
"""
Args:
feat: 参考地图的特征向量
Returns:
logits: 输出logits [BS, num_classes, H, W]
"""
x = self.fc(feat)
x = x.view(-1, self.base_ch*8, 4, 4)
x = self.upsample(x)
x = self.input(x)
x = self.unet(x)
x = F.interpolate(x, (13, 13), mode='bilinear')
x = self.output(x)
return x, F.softmax(x, dim=1)

10
ginka/model/output.py Normal file
View File

@ -0,0 +1,10 @@
import torch
import torch.nn as nn
class GinkaOutput(nn.Module):
    """Resize feature maps to a fixed spatial size with adaptive average pooling.

    Args:
        num_classes: accepted for interface symmetry with the rest of the
            model; not used by this module.
        out_size: target (H, W) of the pooled output.
    """

    def __init__(self, num_classes=32, out_size=(13, 13)):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(out_size)

    def forward(self, x):
        # Channel count is preserved; only the spatial dims change.
        pooled = self.pool(x)
        return pooled

View File

@ -1,15 +0,0 @@
import torch
import torch.nn as nn
class MapDownSample(nn.Module):
def __init__(self, in_ch=32, out_ch=32):
super().__init__()
self.down = nn.Sequential(
nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(in_ch, out_ch, 4, stride=1, padding=0)
)
def forward(self, x):
x = self.down(x)
return x

View File

@ -21,11 +21,11 @@ class GinkaEncoder(nn.Module):
elif block == 'SEBlock':
self.conv.append(SEBlock(out_channels))
self.conv.append(nn.GELU())
self.pool = nn.MaxPool2d(2)
self.down = nn.Conv2d(out_channels, out_channels, 3, stride=2, padding=1)
def forward(self, x):
x_res = self.conv(x)
x_down = self.pool(x_res)
x_down = self.down(x_res)
return x_down, x_res
class GinkaDecoder(nn.Module):
@ -53,23 +53,24 @@ class GinkaDecoder(nn.Module):
return x
class GinkaBottleneck(nn.Module):
def __init__(self, in_channels, out_channels):
def __init__(self, in_channels, out_channels, attention=False):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.InstanceNorm2d(out_channels),
nn.GELU(),
SEBlock(out_channels),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.InstanceNorm2d(out_channels),
nn.GELU(),
)
if attention:
self.conv.append(SEBlock(out_channels))
self.conv.append(nn.GELU())
def forward(self, x):
return self.conv(x)
class GinkaUNet(nn.Module):
def __init__(self, in_ch=32, out_ch=32):
def __init__(self, in_ch=64, out_ch=32):
"""Ginka Model UNet 部分
"""
super().__init__()
@ -78,7 +79,7 @@ class GinkaUNet(nn.Module):
self.down3 = GinkaEncoder(in_ch*4, in_ch*8, attention=True, block='SEBlock')
self.down4 = GinkaEncoder(in_ch*8, in_ch*16, attention=True, block='SEBlock')
self.bottleneck = GinkaBottleneck(in_ch*16, in_ch*16)
self.bottleneck = GinkaBottleneck(in_ch*16, in_ch*16, attention=True)
self.up1 = GinkaDecoder(in_ch*16, in_ch*8, attention=True, block='SEBlock')
self.up2 = GinkaDecoder(in_ch*8, in_ch*4, attention=True, block='SEBlock')
@ -97,9 +98,9 @@ class GinkaUNet(nn.Module):
x = self.bottleneck(x_down4)
x = self.up1(x, skip4) # 用 down2 的 skip
x = self.up2(x, skip3) # 用 down2 的 skip
x = self.up3(x, skip2) # 用 down1 的 skip
x = self.up4(x, skip1) # 用 down1 的 skip
x = self.up1(x, skip4)
x = self.up2(x, skip3)
x = self.up3(x, skip2)
x = self.up4(x, skip1)
return self.final(x)

View File

@ -91,8 +91,6 @@ def train():
graph1 = graph1.to(device)
graph2 = graph2.to(device)
# print(map1.shape, map2.shape)
# 前向传播
optimizer.zero_grad()
vision_feat1, topo_feat1 = model(map1, graph1)

View File

@ -44,8 +44,6 @@ def validate():
vision_feat1, topo_feat1 = model(map1_val, graph1)
vision_feat2, topo_feat2 = model(map2_val, graph2)
print(vision_feat1.isnan().any().item(), topo_feat1.isnan().any().item(), vision_feat2.isnan().any().item(), topo_feat2.isnan().any().item())
vision_pred_val = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
topo_pred_val = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
loss_val = criterion(

View File

@ -1,7 +1,6 @@
import torch
from torch_geometric.data import Data, Batch
def differentiable_convert_to_data(map_probs: torch.Tensor) -> Data:
"""
可导的图结构转换返回 PyG Data 对象