mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-22 19:15:23 +08:00
feat: 改进模型与相似度计算
This commit is contained in:
parent
def3cc5c3f
commit
a5510ef211
@ -206,6 +206,7 @@ function generateSimilarData(id: string, map: number[][]) {
|
|||||||
const id2 = `${id}.S${i}`;
|
const id2 = `${id}.S${i}`;
|
||||||
const sid = `${id}:${id2}`;
|
const sid = `${id}:${id2}`;
|
||||||
const simi = compareMap(id, id2, map, clone);
|
const simi = compareMap(id, id2, map, clone);
|
||||||
|
|
||||||
res.push([
|
res.push([
|
||||||
sid,
|
sid,
|
||||||
{
|
{
|
||||||
|
|||||||
@ -62,7 +62,7 @@ function weisfeilerLehmanIteration(
|
|||||||
|
|
||||||
const compositeLabel = `${node.currentLabel}|${neighborLabels.join(
|
const compositeLabel = `${node.currentLabel}|${neighborLabels.join(
|
||||||
','
|
','
|
||||||
)}`.slice(0, 4096);
|
)}`.slice(0, 8192);
|
||||||
|
|
||||||
newLabels.push(compositeLabel);
|
newLabels.push(compositeLabel);
|
||||||
});
|
});
|
||||||
@ -157,7 +157,7 @@ export function overallSimilarity(
|
|||||||
const min = Math.min(ga.graph.size, gb.graph.size);
|
const min = Math.min(ga.graph.size, gb.graph.size);
|
||||||
const iterations = Math.ceil(Math.max(1, Math.log(min)));
|
const iterations = Math.ceil(Math.max(1, Math.log(min)));
|
||||||
const similarity = wlKernel(ga, gb, iterations);
|
const similarity = wlKernel(ga, gb, iterations);
|
||||||
if (similarity > maxSimilarity) {
|
if (similarity > maxSimilarity && !isNaN(similarity)) {
|
||||||
maxSimilarity = similarity;
|
maxSimilarity = similarity;
|
||||||
maxGraph = gb;
|
maxGraph = gb;
|
||||||
}
|
}
|
||||||
@ -171,5 +171,9 @@ export function overallSimilarity(
|
|||||||
const reduction =
|
const reduction =
|
||||||
1 / (1 + Math.abs(a.unreachable.size - b.unreachable.size));
|
1 / (1 + Math.abs(a.unreachable.size - b.unreachable.size));
|
||||||
// 取根号使结果更接近线性
|
// 取根号使结果更接近线性
|
||||||
return Math.sqrt(totalSimilarity / graphsA.length) * reduction;
|
if (graphsA.length === 0) {
|
||||||
|
return 0;
|
||||||
|
} else {
|
||||||
|
return Math.sqrt(totalSimilarity / graphsA.length) * reduction;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
41
ginka/model/input.py
Normal file
41
ginka/model/input.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
class ResidualUpsampleBlock(nn.Module):
|
||||||
|
def __init__(self, in_ch, out_ch):
|
||||||
|
super().__init__()
|
||||||
|
self.conv = nn.Sequential(
|
||||||
|
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
|
||||||
|
nn.Conv2d(in_ch, out_ch, 3, padding=1),
|
||||||
|
nn.InstanceNorm2d(out_ch),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Conv2d(out_ch, out_ch, 3, padding=1),
|
||||||
|
nn.InstanceNorm2d(out_ch),
|
||||||
|
nn.GELU()
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.conv(x)
|
||||||
|
|
||||||
|
class GinkaInput(nn.Module):
|
||||||
|
def __init__(self, feat_dim=1024, out_ch=64):
|
||||||
|
super().__init__()
|
||||||
|
fc_dim = out_ch * 8 * 4 * 4
|
||||||
|
self.out_ch = out_ch
|
||||||
|
self.fc = nn.Sequential(
|
||||||
|
nn.Linear(feat_dim, fc_dim),
|
||||||
|
nn.BatchNorm1d(fc_dim),
|
||||||
|
nn.ReLU()
|
||||||
|
)
|
||||||
|
self.upsample = nn.Sequential(
|
||||||
|
ResidualUpsampleBlock(out_ch*8, out_ch*8),
|
||||||
|
ResidualUpsampleBlock(out_ch*8, out_ch*4),
|
||||||
|
ResidualUpsampleBlock(out_ch*4, out_ch)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.fc(x)
|
||||||
|
x = x.view(-1, self.out_ch*8, 4, 4)
|
||||||
|
x = self.upsample(x)
|
||||||
|
return x
|
||||||
@ -2,49 +2,31 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from .unet import GinkaUNet
|
from .unet import GinkaUNet
|
||||||
from .sample import MapDownSample
|
from .input import GinkaInput
|
||||||
|
from .output import GinkaOutput
|
||||||
|
|
||||||
class GinkaModel(nn.Module):
|
class GinkaModel(nn.Module):
|
||||||
def __init__(self, feat_dim=1024, base_ch=64, num_classes=32):
|
def __init__(self, feat_dim=1024, base_ch=64, num_classes=32):
|
||||||
"""Ginka Model 模型定义部分
|
"""Ginka Model 模型定义部分
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.base_ch = base_ch
|
self.input = GinkaInput(feat_dim, base_ch)
|
||||||
fc_dim = base_ch * 8 * 4 * 4
|
|
||||||
self.fc = nn.Sequential(
|
|
||||||
nn.Linear(feat_dim, fc_dim),
|
|
||||||
nn.BatchNorm1d(fc_dim),
|
|
||||||
nn.ReLU()
|
|
||||||
)
|
|
||||||
self.upsample = nn.Sequential(
|
|
||||||
nn.Conv2d(base_ch*8, base_ch*16, kernel_size=3, stride=1, padding=1),
|
|
||||||
nn.PixelShuffle(2),
|
|
||||||
nn.InstanceNorm2d(base_ch*4),
|
|
||||||
nn.ReLU(),
|
|
||||||
nn.Conv2d(base_ch*4, base_ch*8, kernel_size=3, stride=1, padding=1),
|
|
||||||
nn.PixelShuffle(2),
|
|
||||||
nn.InstanceNorm2d(base_ch*2),
|
|
||||||
nn.ReLU(),
|
|
||||||
nn.Conv2d(base_ch*2, base_ch*4, kernel_size=3, stride=1, padding=1),
|
|
||||||
nn.PixelShuffle(2),
|
|
||||||
nn.InstanceNorm2d(base_ch),
|
|
||||||
nn.ReLU(),
|
|
||||||
)
|
|
||||||
self.unet = GinkaUNet(base_ch, num_classes)
|
self.unet = GinkaUNet(base_ch, num_classes)
|
||||||
self.down_sample = MapDownSample(num_classes, num_classes)
|
self.output = GinkaOutput(num_classes, (13, 13))
|
||||||
self.pool = nn.AdaptiveMaxPool2d((13, 13))
|
print(f"Input parameters: {sum(p.numel() for p in self.input.parameters())}")
|
||||||
|
print(f"UNet parameters: {sum(p.numel() for p in self.unet.parameters())}")
|
||||||
|
print(f"Output parameters: {sum(p.numel() for p in self.output.parameters())}")
|
||||||
|
print(f"Total parameters: {sum(p.numel() for p in self.parameters())}")
|
||||||
|
|
||||||
def forward(self, feat):
|
def forward(self, x):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
feat: 参考地图的特征向量
|
feat: 参考地图的特征向量
|
||||||
Returns:
|
Returns:
|
||||||
logits: 输出logits [BS, num_classes, H, W]
|
logits: 输出logits [BS, num_classes, H, W]
|
||||||
"""
|
"""
|
||||||
x = self.fc(feat)
|
x = self.input(x)
|
||||||
x = x.view(-1, self.base_ch*8, 4, 4)
|
|
||||||
x = self.upsample(x)
|
|
||||||
x = self.unet(x)
|
x = self.unet(x)
|
||||||
x = F.interpolate(x, (13, 13), mode='bilinear')
|
x = self.output(x)
|
||||||
return x, F.softmax(x, dim=1)
|
return x, F.softmax(x, dim=1)
|
||||||
|
|
||||||
10
ginka/model/output.py
Normal file
10
ginka/model/output.py
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
class GinkaOutput(nn.Module):
|
||||||
|
def __init__(self, num_classes=32, out_size=(13, 13)):
|
||||||
|
super().__init__()
|
||||||
|
self.pool = nn.AdaptiveAvgPool2d(out_size)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.pool(x)
|
||||||
@ -1,15 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
class MapDownSample(nn.Module):
|
|
||||||
def __init__(self, in_ch=32, out_ch=32):
|
|
||||||
super().__init__()
|
|
||||||
self.down = nn.Sequential(
|
|
||||||
nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1),
|
|
||||||
nn.ReLU(),
|
|
||||||
nn.Conv2d(in_ch, out_ch, 4, stride=1, padding=0)
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.down(x)
|
|
||||||
return x
|
|
||||||
@ -21,11 +21,11 @@ class GinkaEncoder(nn.Module):
|
|||||||
elif block == 'SEBlock':
|
elif block == 'SEBlock':
|
||||||
self.conv.append(SEBlock(out_channels))
|
self.conv.append(SEBlock(out_channels))
|
||||||
self.conv.append(nn.GELU())
|
self.conv.append(nn.GELU())
|
||||||
self.pool = nn.MaxPool2d(2)
|
self.down = nn.Conv2d(out_channels, out_channels, 3, stride=2, padding=1)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x_res = self.conv(x)
|
x_res = self.conv(x)
|
||||||
x_down = self.pool(x_res)
|
x_down = self.down(x_res)
|
||||||
return x_down, x_res
|
return x_down, x_res
|
||||||
|
|
||||||
class GinkaDecoder(nn.Module):
|
class GinkaDecoder(nn.Module):
|
||||||
@ -53,23 +53,24 @@ class GinkaDecoder(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
class GinkaBottleneck(nn.Module):
|
class GinkaBottleneck(nn.Module):
|
||||||
def __init__(self, in_channels, out_channels):
|
def __init__(self, in_channels, out_channels, attention=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.conv = nn.Sequential(
|
self.conv = nn.Sequential(
|
||||||
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
|
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
|
||||||
nn.InstanceNorm2d(out_channels),
|
nn.InstanceNorm2d(out_channels),
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
SEBlock(out_channels),
|
|
||||||
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
|
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
|
||||||
nn.InstanceNorm2d(out_channels),
|
nn.InstanceNorm2d(out_channels),
|
||||||
nn.GELU(),
|
|
||||||
)
|
)
|
||||||
|
if attention:
|
||||||
|
self.conv.append(SEBlock(out_channels))
|
||||||
|
self.conv.append(nn.GELU())
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.conv(x)
|
return self.conv(x)
|
||||||
|
|
||||||
class GinkaUNet(nn.Module):
|
class GinkaUNet(nn.Module):
|
||||||
def __init__(self, in_ch=32, out_ch=32):
|
def __init__(self, in_ch=64, out_ch=32):
|
||||||
"""Ginka Model UNet 部分
|
"""Ginka Model UNet 部分
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -78,7 +79,7 @@ class GinkaUNet(nn.Module):
|
|||||||
self.down3 = GinkaEncoder(in_ch*4, in_ch*8, attention=True, block='SEBlock')
|
self.down3 = GinkaEncoder(in_ch*4, in_ch*8, attention=True, block='SEBlock')
|
||||||
self.down4 = GinkaEncoder(in_ch*8, in_ch*16, attention=True, block='SEBlock')
|
self.down4 = GinkaEncoder(in_ch*8, in_ch*16, attention=True, block='SEBlock')
|
||||||
|
|
||||||
self.bottleneck = GinkaBottleneck(in_ch*16, in_ch*16)
|
self.bottleneck = GinkaBottleneck(in_ch*16, in_ch*16, attention=True)
|
||||||
|
|
||||||
self.up1 = GinkaDecoder(in_ch*16, in_ch*8, attention=True, block='SEBlock')
|
self.up1 = GinkaDecoder(in_ch*16, in_ch*8, attention=True, block='SEBlock')
|
||||||
self.up2 = GinkaDecoder(in_ch*8, in_ch*4, attention=True, block='SEBlock')
|
self.up2 = GinkaDecoder(in_ch*8, in_ch*4, attention=True, block='SEBlock')
|
||||||
@ -97,9 +98,9 @@ class GinkaUNet(nn.Module):
|
|||||||
|
|
||||||
x = self.bottleneck(x_down4)
|
x = self.bottleneck(x_down4)
|
||||||
|
|
||||||
x = self.up1(x, skip4) # 用 down2 的 skip
|
x = self.up1(x, skip4)
|
||||||
x = self.up2(x, skip3) # 用 down2 的 skip
|
x = self.up2(x, skip3)
|
||||||
x = self.up3(x, skip2) # 用 down1 的 skip
|
x = self.up3(x, skip2)
|
||||||
x = self.up4(x, skip1) # 用 down1 的 skip
|
x = self.up4(x, skip1)
|
||||||
|
|
||||||
return self.final(x)
|
return self.final(x)
|
||||||
|
|||||||
@ -91,8 +91,6 @@ def train():
|
|||||||
graph1 = graph1.to(device)
|
graph1 = graph1.to(device)
|
||||||
graph2 = graph2.to(device)
|
graph2 = graph2.to(device)
|
||||||
|
|
||||||
# print(map1.shape, map2.shape)
|
|
||||||
|
|
||||||
# 前向传播
|
# 前向传播
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
vision_feat1, topo_feat1 = model(map1, graph1)
|
vision_feat1, topo_feat1 = model(map1, graph1)
|
||||||
|
|||||||
@ -44,8 +44,6 @@ def validate():
|
|||||||
vision_feat1, topo_feat1 = model(map1_val, graph1)
|
vision_feat1, topo_feat1 = model(map1_val, graph1)
|
||||||
vision_feat2, topo_feat2 = model(map2_val, graph2)
|
vision_feat2, topo_feat2 = model(map2_val, graph2)
|
||||||
|
|
||||||
print(vision_feat1.isnan().any().item(), topo_feat1.isnan().any().item(), vision_feat2.isnan().any().item(), topo_feat2.isnan().any().item())
|
|
||||||
|
|
||||||
vision_pred_val = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
|
vision_pred_val = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
|
||||||
topo_pred_val = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
|
topo_pred_val = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
|
||||||
loss_val = criterion(
|
loss_val = criterion(
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
from torch_geometric.data import Data, Batch
|
from torch_geometric.data import Data, Batch
|
||||||
|
|
||||||
|
|
||||||
def differentiable_convert_to_data(map_probs: torch.Tensor) -> Data:
|
def differentiable_convert_to_data(map_probs: torch.Tensor) -> Data:
|
||||||
"""
|
"""
|
||||||
可导的图结构转换(返回 PyG Data 对象)
|
可导的图结构转换(返回 PyG Data 对象)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user