mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-17 23:21:20 +08:00
feat: 改进模型与相似度计算
This commit is contained in:
parent
def3cc5c3f
commit
a5510ef211
@ -206,6 +206,7 @@ function generateSimilarData(id: string, map: number[][]) {
|
||||
const id2 = `${id}.S${i}`;
|
||||
const sid = `${id}:${id2}`;
|
||||
const simi = compareMap(id, id2, map, clone);
|
||||
|
||||
res.push([
|
||||
sid,
|
||||
{
|
||||
|
||||
@ -62,7 +62,7 @@ function weisfeilerLehmanIteration(
|
||||
|
||||
const compositeLabel = `${node.currentLabel}|${neighborLabels.join(
|
||||
','
|
||||
)}`.slice(0, 4096);
|
||||
)}`.slice(0, 8192);
|
||||
|
||||
newLabels.push(compositeLabel);
|
||||
});
|
||||
@ -157,7 +157,7 @@ export function overallSimilarity(
|
||||
const min = Math.min(ga.graph.size, gb.graph.size);
|
||||
const iterations = Math.ceil(Math.max(1, Math.log(min)));
|
||||
const similarity = wlKernel(ga, gb, iterations);
|
||||
if (similarity > maxSimilarity) {
|
||||
if (similarity > maxSimilarity && !isNaN(similarity)) {
|
||||
maxSimilarity = similarity;
|
||||
maxGraph = gb;
|
||||
}
|
||||
@ -171,5 +171,9 @@ export function overallSimilarity(
|
||||
const reduction =
|
||||
1 / (1 + Math.abs(a.unreachable.size - b.unreachable.size));
|
||||
// 取根号使结果更接近线性
|
||||
return Math.sqrt(totalSimilarity / graphsA.length) * reduction;
|
||||
if (graphsA.length === 0) {
|
||||
return 0;
|
||||
} else {
|
||||
return Math.sqrt(totalSimilarity / graphsA.length) * reduction;
|
||||
}
|
||||
}
|
||||
|
||||
41
ginka/model/input.py
Normal file
41
ginka/model/input.py
Normal file
@ -0,0 +1,41 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class ResidualUpsampleBlock(nn.Module):
|
||||
def __init__(self, in_ch, out_ch):
|
||||
super().__init__()
|
||||
self.conv = nn.Sequential(
|
||||
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
|
||||
nn.Conv2d(in_ch, out_ch, 3, padding=1),
|
||||
nn.InstanceNorm2d(out_ch),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(out_ch, out_ch, 3, padding=1),
|
||||
nn.InstanceNorm2d(out_ch),
|
||||
nn.GELU()
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
class GinkaInput(nn.Module):
|
||||
def __init__(self, feat_dim=1024, out_ch=64):
|
||||
super().__init__()
|
||||
fc_dim = out_ch * 8 * 4 * 4
|
||||
self.out_ch = out_ch
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(feat_dim, fc_dim),
|
||||
nn.BatchNorm1d(fc_dim),
|
||||
nn.ReLU()
|
||||
)
|
||||
self.upsample = nn.Sequential(
|
||||
ResidualUpsampleBlock(out_ch*8, out_ch*8),
|
||||
ResidualUpsampleBlock(out_ch*8, out_ch*4),
|
||||
ResidualUpsampleBlock(out_ch*4, out_ch)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc(x)
|
||||
x = x.view(-1, self.out_ch*8, 4, 4)
|
||||
x = self.upsample(x)
|
||||
return x
|
||||
@ -2,49 +2,31 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .unet import GinkaUNet
|
||||
from .sample import MapDownSample
|
||||
from .input import GinkaInput
|
||||
from .output import GinkaOutput
|
||||
|
||||
class GinkaModel(nn.Module):
|
||||
def __init__(self, feat_dim=1024, base_ch=64, num_classes=32):
|
||||
"""Ginka Model 模型定义部分
|
||||
"""
|
||||
super().__init__()
|
||||
self.base_ch = base_ch
|
||||
fc_dim = base_ch * 8 * 4 * 4
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(feat_dim, fc_dim),
|
||||
nn.BatchNorm1d(fc_dim),
|
||||
nn.ReLU()
|
||||
)
|
||||
self.upsample = nn.Sequential(
|
||||
nn.Conv2d(base_ch*8, base_ch*16, kernel_size=3, stride=1, padding=1),
|
||||
nn.PixelShuffle(2),
|
||||
nn.InstanceNorm2d(base_ch*4),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(base_ch*4, base_ch*8, kernel_size=3, stride=1, padding=1),
|
||||
nn.PixelShuffle(2),
|
||||
nn.InstanceNorm2d(base_ch*2),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(base_ch*2, base_ch*4, kernel_size=3, stride=1, padding=1),
|
||||
nn.PixelShuffle(2),
|
||||
nn.InstanceNorm2d(base_ch),
|
||||
nn.ReLU(),
|
||||
)
|
||||
self.input = GinkaInput(feat_dim, base_ch)
|
||||
self.unet = GinkaUNet(base_ch, num_classes)
|
||||
self.down_sample = MapDownSample(num_classes, num_classes)
|
||||
self.pool = nn.AdaptiveMaxPool2d((13, 13))
|
||||
self.output = GinkaOutput(num_classes, (13, 13))
|
||||
print(f"Input parameters: {sum(p.numel() for p in self.input.parameters())}")
|
||||
print(f"UNet parameters: {sum(p.numel() for p in self.unet.parameters())}")
|
||||
print(f"Output parameters: {sum(p.numel() for p in self.output.parameters())}")
|
||||
print(f"Total parameters: {sum(p.numel() for p in self.parameters())}")
|
||||
|
||||
def forward(self, feat):
|
||||
def forward(self, x):
|
||||
"""
|
||||
Args:
|
||||
feat: 参考地图的特征向量
|
||||
Returns:
|
||||
logits: 输出logits [BS, num_classes, H, W]
|
||||
"""
|
||||
x = self.fc(feat)
|
||||
x = x.view(-1, self.base_ch*8, 4, 4)
|
||||
x = self.upsample(x)
|
||||
x = self.input(x)
|
||||
x = self.unet(x)
|
||||
x = F.interpolate(x, (13, 13), mode='bilinear')
|
||||
x = self.output(x)
|
||||
return x, F.softmax(x, dim=1)
|
||||
|
||||
10
ginka/model/output.py
Normal file
10
ginka/model/output.py
Normal file
@ -0,0 +1,10 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class GinkaOutput(nn.Module):
|
||||
def __init__(self, num_classes=32, out_size=(13, 13)):
|
||||
super().__init__()
|
||||
self.pool = nn.AdaptiveAvgPool2d(out_size)
|
||||
|
||||
def forward(self, x):
|
||||
return self.pool(x)
|
||||
@ -1,15 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class MapDownSample(nn.Module):
|
||||
def __init__(self, in_ch=32, out_ch=32):
|
||||
super().__init__()
|
||||
self.down = nn.Sequential(
|
||||
nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(in_ch, out_ch, 4, stride=1, padding=0)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.down(x)
|
||||
return x
|
||||
@ -21,11 +21,11 @@ class GinkaEncoder(nn.Module):
|
||||
elif block == 'SEBlock':
|
||||
self.conv.append(SEBlock(out_channels))
|
||||
self.conv.append(nn.GELU())
|
||||
self.pool = nn.MaxPool2d(2)
|
||||
self.down = nn.Conv2d(out_channels, out_channels, 3, stride=2, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x_res = self.conv(x)
|
||||
x_down = self.pool(x_res)
|
||||
x_down = self.down(x_res)
|
||||
return x_down, x_res
|
||||
|
||||
class GinkaDecoder(nn.Module):
|
||||
@ -53,23 +53,24 @@ class GinkaDecoder(nn.Module):
|
||||
return x
|
||||
|
||||
class GinkaBottleneck(nn.Module):
|
||||
def __init__(self, in_channels, out_channels):
|
||||
def __init__(self, in_channels, out_channels, attention=False):
|
||||
super().__init__()
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
|
||||
nn.InstanceNorm2d(out_channels),
|
||||
nn.GELU(),
|
||||
SEBlock(out_channels),
|
||||
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
|
||||
nn.InstanceNorm2d(out_channels),
|
||||
nn.GELU(),
|
||||
)
|
||||
if attention:
|
||||
self.conv.append(SEBlock(out_channels))
|
||||
self.conv.append(nn.GELU())
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
class GinkaUNet(nn.Module):
|
||||
def __init__(self, in_ch=32, out_ch=32):
|
||||
def __init__(self, in_ch=64, out_ch=32):
|
||||
"""Ginka Model UNet 部分
|
||||
"""
|
||||
super().__init__()
|
||||
@ -78,7 +79,7 @@ class GinkaUNet(nn.Module):
|
||||
self.down3 = GinkaEncoder(in_ch*4, in_ch*8, attention=True, block='SEBlock')
|
||||
self.down4 = GinkaEncoder(in_ch*8, in_ch*16, attention=True, block='SEBlock')
|
||||
|
||||
self.bottleneck = GinkaBottleneck(in_ch*16, in_ch*16)
|
||||
self.bottleneck = GinkaBottleneck(in_ch*16, in_ch*16, attention=True)
|
||||
|
||||
self.up1 = GinkaDecoder(in_ch*16, in_ch*8, attention=True, block='SEBlock')
|
||||
self.up2 = GinkaDecoder(in_ch*8, in_ch*4, attention=True, block='SEBlock')
|
||||
@ -97,9 +98,9 @@ class GinkaUNet(nn.Module):
|
||||
|
||||
x = self.bottleneck(x_down4)
|
||||
|
||||
x = self.up1(x, skip4) # 用 down2 的 skip
|
||||
x = self.up2(x, skip3) # 用 down2 的 skip
|
||||
x = self.up3(x, skip2) # 用 down1 的 skip
|
||||
x = self.up4(x, skip1) # 用 down1 的 skip
|
||||
x = self.up1(x, skip4)
|
||||
x = self.up2(x, skip3)
|
||||
x = self.up3(x, skip2)
|
||||
x = self.up4(x, skip1)
|
||||
|
||||
return self.final(x)
|
||||
|
||||
@ -91,8 +91,6 @@ def train():
|
||||
graph1 = graph1.to(device)
|
||||
graph2 = graph2.to(device)
|
||||
|
||||
# print(map1.shape, map2.shape)
|
||||
|
||||
# 前向传播
|
||||
optimizer.zero_grad()
|
||||
vision_feat1, topo_feat1 = model(map1, graph1)
|
||||
|
||||
@ -44,8 +44,6 @@ def validate():
|
||||
vision_feat1, topo_feat1 = model(map1_val, graph1)
|
||||
vision_feat2, topo_feat2 = model(map2_val, graph2)
|
||||
|
||||
print(vision_feat1.isnan().any().item(), topo_feat1.isnan().any().item(), vision_feat2.isnan().any().item(), topo_feat2.isnan().any().item())
|
||||
|
||||
vision_pred_val = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
|
||||
topo_pred_val = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
|
||||
loss_val = criterion(
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import torch
|
||||
from torch_geometric.data import Data, Batch
|
||||
|
||||
|
||||
def differentiable_convert_to_data(map_probs: torch.Tensor) -> Data:
|
||||
"""
|
||||
可导的图结构转换(返回 PyG Data 对象)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user