mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
refactor: minamo vision 改为向量输出
This commit is contained in:
parent
7d0f567cc1
commit
ef9d3d1504
@ -1,7 +1,7 @@
|
||||
import torch.nn as nn
|
||||
|
||||
class MinamoLoss(nn.Module):
|
||||
def __init__(self, vision_weight=0, topo_weight=1):
|
||||
def __init__(self, vision_weight=1, topo_weight=0):
|
||||
super().__init__()
|
||||
self.vision_weight = vision_weight
|
||||
self.topo_weight = topo_weight
|
||||
|
||||
@ -4,17 +4,21 @@ from .vision import MinamoVisionModel
|
||||
from .topo import MinamoTopoModel
|
||||
|
||||
class MinamoModel(nn.Module):
    """Score a pair of maps by visual and topological similarity.

    Each branch encodes its two inputs independently into feature vectors;
    the pair is then compared with cosine similarity.
    """

    def __init__(self, tile_types=32):
        super().__init__()
        # Visual-similarity branch.
        self.vision_model = MinamoVisionModel(tile_types)
        # Topological-similarity branch.
        self.topo_model = MinamoTopoModel(tile_types)

    def forward(self, map1, map2, graph1, graph2):
        """Return ``(vision_sim, topo_sim)`` for the two map/graph pairs.

        Both similarities are cosine similarities over the last feature
        axis, with a trailing singleton axis appended by ``unsqueeze(-1)``.
        """
        # Encode in a fixed order: both maps first, then both graphs.
        map_vec_a = self.vision_model(map1)
        map_vec_b = self.vision_model(map2)
        graph_vec_a = self.topo_model(graph1)
        graph_vec_b = self.topo_model(graph2)

        vision_sim = F.cosine_similarity(map_vec_a, map_vec_b, dim=-1)
        topo_sim = F.cosine_similarity(graph_vec_a, graph_vec_b, dim=-1)

        return vision_sim.unsqueeze(-1), topo_sim.unsqueeze(-1)
|
||||
|
||||
@ -6,7 +6,7 @@ from torch_geometric.data import Data
|
||||
|
||||
class MinamoTopoModel(nn.Module):
|
||||
def __init__(
|
||||
self, tile_types=32, emb_dim=16, hidden_dim=32, out_dim=16, mlp_dim=8
|
||||
self, tile_types=32, emb_dim=64, hidden_dim=64, out_dim=512, mlp_dim=128
|
||||
):
|
||||
super().__init__()
|
||||
# 嵌入层
|
||||
@ -27,9 +27,7 @@ class MinamoTopoModel(nn.Module):
|
||||
|
||||
# 增强MLP
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(out_dim, mlp_dim*2),
|
||||
nn.ReLU(),
|
||||
nn.Linear(mlp_dim*2, mlp_dim)
|
||||
nn.Linear(out_dim, mlp_dim),
|
||||
)
|
||||
|
||||
def forward(self, graph: Data):
|
||||
@ -50,6 +48,8 @@ class MinamoTopoModel(nn.Module):
|
||||
# x, _, _, batch, _, _ = self.pool(x, graph.edge_index, batch=graph.batch)
|
||||
x = global_mean_pool(x, graph.batch)
|
||||
|
||||
topo_vec = self.fc(x)
|
||||
|
||||
# 增强MLP
|
||||
return self.fc(x)
|
||||
return F.normalize(topo_vec, p=2, dim=-1)
|
||||
|
||||
@ -3,27 +3,35 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class DualAttention(nn.Module):
    """Re-weight a feature map with summed spatial and channel gates.

    Three sigmoid gates are computed from the input:

    * ``spatial``: a 3x3 convolution down to one channel.
    * ``channel``: global average pooling followed by a 1x1-conv bottleneck.
    * ``channel_max``: global max pooling followed by the same kind of
      bottleneck (with its own weights).

    The gates are added together (so the combined weight lies in ``(0, 3)``,
    not ``(0, 1)``) and multiplied into the input.

    Args:
        in_channels: Number of channels of the incoming feature map.
        reduction: Divisor applied to ``in_channels`` to obtain the
            bottleneck width of the two channel gates.
    """

    def __init__(self, in_channels, reduction=8):
        super().__init__()
        # ``in_channels // reduction`` is 0 whenever in_channels < reduction,
        # which would create a zero-width bottleneck; keep at least 1 channel.
        squeezed = max(1, in_channels // reduction)

        # Spatial attention.
        self.spatial = nn.Sequential(
            nn.Conv2d(in_channels, 1, 3, padding=1),
            nn.Sigmoid(),
        )

        # Channel attention driven by average- and max-pooled statistics.
        self.channel = self._channel_gate(
            nn.AdaptiveAvgPool2d(1), in_channels, squeezed
        )
        self.channel_max = self._channel_gate(
            nn.AdaptiveMaxPool2d(1), in_channels, squeezed
        )

    @staticmethod
    def _channel_gate(pool, in_channels, squeezed):
        """Build pool -> 1x1 conv -> ReLU -> 1x1 conv -> sigmoid."""
        return nn.Sequential(
            pool,
            nn.Conv2d(in_channels, squeezed, 1),
            nn.ReLU(),
            nn.Conv2d(squeezed, in_channels, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` scaled by the sum of the three attention gates."""
        attn = self.spatial(x) + self.channel(x) + self.channel_max(x)
        return x * attn
|
||||
|
||||
class MinamoVisionModel(nn.Module):
    """Encode a tile map into an L2-normalised feature vector.

    Tile ids are embedded, passed through three convolution stages
    (conv -> batch norm -> DualAttention -> ReLU), globally max-pooled and
    projected to ``out_dim`` features.
    """

    def __init__(self, tile_types=32, embedding_dim=16, conv_channels=64, out_dim=128):
        super().__init__()
        # Embedding layer for the different tile types.
        self.embedding = nn.Embedding(tile_types, embedding_dim)

        narrow = conv_channels
        medium = conv_channels * 2
        wide = conv_channels * 4

        # Convolutional trunk; the first two stages also downsample and
        # apply channel dropout, the last one ends in global max pooling.
        trunk = [
            nn.Conv2d(embedding_dim, narrow, 3, padding=1),
            nn.BatchNorm2d(narrow),
            DualAttention(narrow, reduction=12),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Dropout2d(0.4),

            nn.Conv2d(narrow, medium, 3, padding=1),
            nn.BatchNorm2d(medium),
            DualAttention(medium, reduction=12),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Dropout2d(0.4),

            nn.Conv2d(medium, wide, 3, padding=1),
            nn.BatchNorm2d(wide),
            DualAttention(wide, reduction=12),
            nn.ReLU(),

            nn.AdaptiveMaxPool2d(1),
        ]
        self.vision_conv = nn.Sequential(*trunk)

        # Projection head: the output is a vector, not a similarity score.
        self.vision_head = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(wide, out_dim),
        )

    def forward(self, map):
        """Return the unit-length feature vector for ``map``.

        The embedded tensor is permuted with ``(0, 3, 1, 2)``, so ``map`` is
        a 3-D index tensor — presumably (batch, height, width); confirm
        against the data loader.
        """
        # Move the embedding axis in front of the two spatial axes.
        embedded = self.embedding(map).permute(0, 3, 1, 2)

        pooled = self.vision_conv(embedded)
        flat = pooled.view(pooled.size(0), -1)  # flatten per sample

        projected = self.vision_head(flat)
        return F.normalize(projected, p=2, dim=-1)  # L2-normalise
|
||||
|
||||
@ -52,7 +52,7 @@ def train():
|
||||
)
|
||||
|
||||
# 设定优化器与调度器
|
||||
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
|
||||
optimizer = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-2)
|
||||
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-6)
|
||||
criterion = MinamoLoss()
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user