feat: 改进模型与相似度计算

This commit is contained in:
unanmed 2025-03-26 21:06:28 +08:00
parent def3cc5c3f
commit a5510ef211
10 changed files with 82 additions and 63 deletions

View File

@ -206,6 +206,7 @@ function generateSimilarData(id: string, map: number[][]) {
const id2 = `${id}.S${i}`; const id2 = `${id}.S${i}`;
const sid = `${id}:${id2}`; const sid = `${id}:${id2}`;
const simi = compareMap(id, id2, map, clone); const simi = compareMap(id, id2, map, clone);
res.push([ res.push([
sid, sid,
{ {

View File

@ -62,7 +62,7 @@ function weisfeilerLehmanIteration(
const compositeLabel = `${node.currentLabel}|${neighborLabels.join( const compositeLabel = `${node.currentLabel}|${neighborLabels.join(
',' ','
)}`.slice(0, 4096); )}`.slice(0, 8192);
newLabels.push(compositeLabel); newLabels.push(compositeLabel);
}); });
@ -157,7 +157,7 @@ export function overallSimilarity(
const min = Math.min(ga.graph.size, gb.graph.size); const min = Math.min(ga.graph.size, gb.graph.size);
const iterations = Math.ceil(Math.max(1, Math.log(min))); const iterations = Math.ceil(Math.max(1, Math.log(min)));
const similarity = wlKernel(ga, gb, iterations); const similarity = wlKernel(ga, gb, iterations);
if (similarity > maxSimilarity) { if (similarity > maxSimilarity && !isNaN(similarity)) {
maxSimilarity = similarity; maxSimilarity = similarity;
maxGraph = gb; maxGraph = gb;
} }
@ -171,5 +171,9 @@ export function overallSimilarity(
const reduction = const reduction =
1 / (1 + Math.abs(a.unreachable.size - b.unreachable.size)); 1 / (1 + Math.abs(a.unreachable.size - b.unreachable.size));
// 取根号使结果更接近线性 // 取根号使结果更接近线性
return Math.sqrt(totalSimilarity / graphsA.length) * reduction; if (graphsA.length === 0) {
return 0;
} else {
return Math.sqrt(totalSimilarity / graphsA.length) * reduction;
}
} }

41
ginka/model/input.py Normal file
View File

@ -0,0 +1,41 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class ResidualUpsampleBlock(nn.Module):
    """Double the spatial resolution, then refine with two 3x3 conv stages.

    The stack is: bilinear ``Upsample`` (scale factor 2, ``align_corners=True``)
    followed by two stages of ``Conv2d(3x3, padding=1) -> InstanceNorm2d -> GELU``.
    The first stage maps ``in_ch`` to ``out_ch`` channels; the second keeps
    ``out_ch`` channels.

    NOTE(review): despite the class name, ``forward`` adds no skip/residual
    connection; it only runs the sequential stack.

    Args:
        in_ch: Number of channels of the incoming feature map.
        out_ch: Number of channels of the produced feature map.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Layer order is preserved so that the state_dict keys
        # (conv.1.*, conv.4.*) and the parameter-init order stay the same.
        layers = [
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
        ]
        for stage_in in (in_ch, out_ch):
            layers.append(nn.Conv2d(stage_in, out_ch, 3, padding=1))
            layers.append(nn.InstanceNorm2d(out_ch))
            layers.append(nn.GELU())
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Return the upsampled and convolved feature map computed from ``x``."""
        upsampled = self.conv(x)
        return upsampled
class GinkaInput(nn.Module):
    """Project a flat feature vector onto a spatial feature map.

    A ``Linear -> BatchNorm1d -> ReLU`` head expands the input to
    ``out_ch * 8 * 4 * 4`` values, which are viewed as an
    ``(out_ch * 8, 4, 4)`` map and passed through three
    ``ResidualUpsampleBlock`` stages with output channels
    ``out_ch * 8``, ``out_ch * 4`` and ``out_ch``.

    Args:
        feat_dim: Size of the input feature vector (``in_features`` of the
            linear layer).
        out_ch: Channel count of the final feature map.
    """

    def __init__(self, feat_dim=1024, out_ch=64):
        super().__init__()
        self.out_ch = out_ch
        wide_ch = out_ch * 8  # channel count of the 4x4 seed map
        fc_dim = wide_ch * 4 * 4
        self.fc = nn.Sequential(
            nn.Linear(feat_dim, fc_dim),
            nn.BatchNorm1d(fc_dim),
            nn.ReLU(),
        )
        # (in_channels, out_channels) for each upsampling stage, in order.
        stage_channels = [
            (wide_ch, wide_ch),
            (wide_ch, out_ch * 4),
            (out_ch * 4, out_ch),
        ]
        self.upsample = nn.Sequential(
            *(ResidualUpsampleBlock(c_in, c_out) for c_in, c_out in stage_channels)
        )

    def forward(self, x):
        """Map feature vectors ``x`` to the upsampled feature map.

        The linear head output is reshaped to ``(-1, out_ch * 8, 4, 4)``
        before the upsampling stages are applied.
        """
        seed_map = self.fc(x).view(-1, self.out_ch * 8, 4, 4)
        return self.upsample(seed_map)

View File

@ -2,49 +2,31 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from .unet import GinkaUNet from .unet import GinkaUNet
from .sample import MapDownSample from .input import GinkaInput
from .output import GinkaOutput
class GinkaModel(nn.Module): class GinkaModel(nn.Module):
def __init__(self, feat_dim=1024, base_ch=64, num_classes=32): def __init__(self, feat_dim=1024, base_ch=64, num_classes=32):
"""Ginka Model 模型定义部分 """Ginka Model 模型定义部分
""" """
super().__init__() super().__init__()
self.base_ch = base_ch self.input = GinkaInput(feat_dim, base_ch)
fc_dim = base_ch * 8 * 4 * 4
self.fc = nn.Sequential(
nn.Linear(feat_dim, fc_dim),
nn.BatchNorm1d(fc_dim),
nn.ReLU()
)
self.upsample = nn.Sequential(
nn.Conv2d(base_ch*8, base_ch*16, kernel_size=3, stride=1, padding=1),
nn.PixelShuffle(2),
nn.InstanceNorm2d(base_ch*4),
nn.ReLU(),
nn.Conv2d(base_ch*4, base_ch*8, kernel_size=3, stride=1, padding=1),
nn.PixelShuffle(2),
nn.InstanceNorm2d(base_ch*2),
nn.ReLU(),
nn.Conv2d(base_ch*2, base_ch*4, kernel_size=3, stride=1, padding=1),
nn.PixelShuffle(2),
nn.InstanceNorm2d(base_ch),
nn.ReLU(),
)
self.unet = GinkaUNet(base_ch, num_classes) self.unet = GinkaUNet(base_ch, num_classes)
self.down_sample = MapDownSample(num_classes, num_classes) self.output = GinkaOutput(num_classes, (13, 13))
self.pool = nn.AdaptiveMaxPool2d((13, 13)) print(f"Input parameters: {sum(p.numel() for p in self.input.parameters())}")
print(f"UNet parameters: {sum(p.numel() for p in self.unet.parameters())}")
print(f"Output parameters: {sum(p.numel() for p in self.output.parameters())}")
print(f"Total parameters: {sum(p.numel() for p in self.parameters())}")
def forward(self, feat): def forward(self, x):
""" """
Args: Args:
feat: 参考地图的特征向量 feat: 参考地图的特征向量
Returns: Returns:
logits: 输出logits [BS, num_classes, H, W] logits: 输出logits [BS, num_classes, H, W]
""" """
x = self.fc(feat) x = self.input(x)
x = x.view(-1, self.base_ch*8, 4, 4)
x = self.upsample(x)
x = self.unet(x) x = self.unet(x)
x = F.interpolate(x, (13, 13), mode='bilinear') x = self.output(x)
return x, F.softmax(x, dim=1) return x, F.softmax(x, dim=1)

10
ginka/model/output.py Normal file
View File

@ -0,0 +1,10 @@
import torch
import torch.nn as nn
class GinkaOutput(nn.Module):
    """Resize a feature map to a fixed spatial size via adaptive average pooling.

    Args:
        num_classes: Accepted by the constructor but not used by this module.
        out_size: Target ``(height, width)`` passed to ``nn.AdaptiveAvgPool2d``.
    """

    def __init__(self, num_classes=32, out_size=(13, 13)):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(out_size)

    def forward(self, x):
        """Return ``x`` pooled to the configured output size."""
        pooled = self.pool(x)
        return pooled

View File

@ -1,15 +0,0 @@
import torch
import torch.nn as nn
class MapDownSample(nn.Module):
    """Shrink a feature map with a strided 3x3 conv followed by an unpadded 4x4 conv.

    The first convolution (stride 2, padding 1) keeps ``in_ch`` channels and
    is followed by ReLU; the second (kernel 4, stride 1, no padding) maps to
    ``out_ch`` channels. For a 32x32 input this yields 16x16 and then 13x13.

    Args:
        in_ch: Channels of the incoming map.
        out_ch: Channels of the produced map.
    """

    def __init__(self, in_ch=32, out_ch=32):
        super().__init__()
        # Created in the original order so state_dict keys remain down.0.* / down.2.*.
        strided_conv = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1)
        activation = nn.ReLU()
        trim_conv = nn.Conv2d(in_ch, out_ch, 4, stride=1, padding=0)
        self.down = nn.Sequential(strided_conv, activation, trim_conv)

    def forward(self, x):
        """Return the down-sampled feature map for ``x``."""
        return self.down(x)

View File

@ -21,11 +21,11 @@ class GinkaEncoder(nn.Module):
elif block == 'SEBlock': elif block == 'SEBlock':
self.conv.append(SEBlock(out_channels)) self.conv.append(SEBlock(out_channels))
self.conv.append(nn.GELU()) self.conv.append(nn.GELU())
self.pool = nn.MaxPool2d(2) self.down = nn.Conv2d(out_channels, out_channels, 3, stride=2, padding=1)
def forward(self, x): def forward(self, x):
x_res = self.conv(x) x_res = self.conv(x)
x_down = self.pool(x_res) x_down = self.down(x_res)
return x_down, x_res return x_down, x_res
class GinkaDecoder(nn.Module): class GinkaDecoder(nn.Module):
@ -53,23 +53,24 @@ class GinkaDecoder(nn.Module):
return x return x
class GinkaBottleneck(nn.Module): class GinkaBottleneck(nn.Module):
def __init__(self, in_channels, out_channels): def __init__(self, in_channels, out_channels, attention=False):
super().__init__() super().__init__()
self.conv = nn.Sequential( self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.InstanceNorm2d(out_channels), nn.InstanceNorm2d(out_channels),
nn.GELU(), nn.GELU(),
SEBlock(out_channels),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.InstanceNorm2d(out_channels), nn.InstanceNorm2d(out_channels),
nn.GELU(),
) )
if attention:
self.conv.append(SEBlock(out_channels))
self.conv.append(nn.GELU())
def forward(self, x): def forward(self, x):
return self.conv(x) return self.conv(x)
class GinkaUNet(nn.Module): class GinkaUNet(nn.Module):
def __init__(self, in_ch=32, out_ch=32): def __init__(self, in_ch=64, out_ch=32):
"""Ginka Model UNet 部分 """Ginka Model UNet 部分
""" """
super().__init__() super().__init__()
@ -78,7 +79,7 @@ class GinkaUNet(nn.Module):
self.down3 = GinkaEncoder(in_ch*4, in_ch*8, attention=True, block='SEBlock') self.down3 = GinkaEncoder(in_ch*4, in_ch*8, attention=True, block='SEBlock')
self.down4 = GinkaEncoder(in_ch*8, in_ch*16, attention=True, block='SEBlock') self.down4 = GinkaEncoder(in_ch*8, in_ch*16, attention=True, block='SEBlock')
self.bottleneck = GinkaBottleneck(in_ch*16, in_ch*16) self.bottleneck = GinkaBottleneck(in_ch*16, in_ch*16, attention=True)
self.up1 = GinkaDecoder(in_ch*16, in_ch*8, attention=True, block='SEBlock') self.up1 = GinkaDecoder(in_ch*16, in_ch*8, attention=True, block='SEBlock')
self.up2 = GinkaDecoder(in_ch*8, in_ch*4, attention=True, block='SEBlock') self.up2 = GinkaDecoder(in_ch*8, in_ch*4, attention=True, block='SEBlock')
@ -97,9 +98,9 @@ class GinkaUNet(nn.Module):
x = self.bottleneck(x_down4) x = self.bottleneck(x_down4)
x = self.up1(x, skip4) # 用 down2 的 skip x = self.up1(x, skip4)
x = self.up2(x, skip3) # 用 down2 的 skip x = self.up2(x, skip3)
x = self.up3(x, skip2) # 用 down1 的 skip x = self.up3(x, skip2)
x = self.up4(x, skip1) # 用 down1 的 skip x = self.up4(x, skip1)
return self.final(x) return self.final(x)

View File

@ -91,8 +91,6 @@ def train():
graph1 = graph1.to(device) graph1 = graph1.to(device)
graph2 = graph2.to(device) graph2 = graph2.to(device)
# print(map1.shape, map2.shape)
# 前向传播 # 前向传播
optimizer.zero_grad() optimizer.zero_grad()
vision_feat1, topo_feat1 = model(map1, graph1) vision_feat1, topo_feat1 = model(map1, graph1)

View File

@ -44,8 +44,6 @@ def validate():
vision_feat1, topo_feat1 = model(map1_val, graph1) vision_feat1, topo_feat1 = model(map1_val, graph1)
vision_feat2, topo_feat2 = model(map2_val, graph2) vision_feat2, topo_feat2 = model(map2_val, graph2)
print(vision_feat1.isnan().any().item(), topo_feat1.isnan().any().item(), vision_feat2.isnan().any().item(), topo_feat2.isnan().any().item())
vision_pred_val = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1) vision_pred_val = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
topo_pred_val = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1) topo_pred_val = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
loss_val = criterion( loss_val = criterion(

View File

@ -1,7 +1,6 @@
import torch import torch
from torch_geometric.data import Data, Batch from torch_geometric.data import Data, Batch
def differentiable_convert_to_data(map_probs: torch.Tensor) -> Data: def differentiable_convert_to_data(map_probs: torch.Tensor) -> Data:
""" """
可导的图结构转换返回 PyG Data 对象 可导的图结构转换返回 PyG Data 对象