From ef9d3d1504c0b3ab2dcd27693a5df0474a01d346 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Mon, 17 Mar 2025 19:50:46 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20minamo=20vision=20=E6=94=B9?= =?UTF-8?q?=E4=B8=BA=E5=90=91=E9=87=8F=E8=BE=93=E5=87=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- minamo/model/loss.py | 2 +- minamo/model/model.py | 12 +++++--- minamo/model/topo.py | 10 +++---- minamo/model/vision.py | 64 ++++++++++++++++++++++++------------------ minamo/train.py | 2 +- 5 files changed, 51 insertions(+), 39 deletions(-) diff --git a/minamo/model/loss.py b/minamo/model/loss.py index 671b137..3212724 100644 --- a/minamo/model/loss.py +++ b/minamo/model/loss.py @@ -1,7 +1,7 @@ import torch.nn as nn class MinamoLoss(nn.Module): - def __init__(self, vision_weight=0, topo_weight=1): + def __init__(self, vision_weight=1, topo_weight=0): super().__init__() self.vision_weight = vision_weight self.topo_weight = topo_weight diff --git a/minamo/model/model.py b/minamo/model/model.py index 9ec2fea..25c322f 100644 --- a/minamo/model/model.py +++ b/minamo/model/model.py @@ -4,17 +4,21 @@ from .vision import MinamoVisionModel from .topo import MinamoTopoModel class MinamoModel(nn.Module): - def __init__(self, tile_types=32, embedding_dim=16, conv_channels=16): + def __init__(self, tile_types=32): super().__init__() # 视觉相似度部分 - self.vision_model = MinamoVisionModel(tile_types, embedding_dim, conv_channels) + self.vision_model = MinamoVisionModel(tile_types) # 拓扑相似度部分 self.topo_model = MinamoTopoModel(tile_types) def forward(self, map1, map2, graph1, graph2): - vision_sim = self.vision_model(map1, map2) + vision_feat1 = self.vision_model(map1) + vision_feat2 = self.vision_model(map2) topo_feat1 = self.topo_model(graph1) topo_feat2 = self.topo_model(graph2) - return vision_sim, F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1) + vision_sim = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1) + topo_sim = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1) + + return vision_sim, topo_sim diff --git a/minamo/model/topo.py b/minamo/model/topo.py index d1b4220..d0910e5 100644 --- a/minamo/model/topo.py +++ b/minamo/model/topo.py @@ -6,7 +6,7 @@ from torch_geometric.data import Data class MinamoTopoModel(nn.Module): def __init__( - self, tile_types=32, emb_dim=16, hidden_dim=32, out_dim=16, mlp_dim=8 + self, tile_types=32, emb_dim=64, hidden_dim=64, out_dim=512, mlp_dim=128 ): super().__init__() # 嵌入层 @@ -27,9 +27,7 @@ class MinamoTopoModel(nn.Module): # 增强MLP self.fc = nn.Sequential( - nn.Linear(out_dim, mlp_dim*2), - nn.ReLU(), - nn.Linear(mlp_dim*2, mlp_dim) + nn.Linear(out_dim, mlp_dim), ) def forward(self, graph: Data): @@ -50,6 +48,8 @@ class MinamoTopoModel(nn.Module): # x, _, _, batch, _, _ = self.pool(x, graph.edge_index, batch=graph.batch) x = global_mean_pool(x, graph.batch) + topo_vec = self.fc(x) + # 增强MLP - return self.fc(x) + return F.normalize(topo_vec, p=2, dim=-1) \ No newline at end of file diff --git a/minamo/model/vision.py b/minamo/model/vision.py index a8fb99c..38ddbe0 100644 --- a/minamo/model/vision.py +++ b/minamo/model/vision.py @@ -3,27 +3,35 @@ import torch.nn as nn import torch.nn.functional as F class DualAttention(nn.Module): - def __init__(self, in_channels): + def __init__(self, in_channels, reduction=8): super().__init__() - # 空间注意力 self.spatial = nn.Sequential( - nn.Conv2d(in_channels, 1, 1), + nn.Conv2d(in_channels, 1, 3, padding=1), nn.Sigmoid() ) - # 通道注意力 + self.channel = nn.Sequential( nn.AdaptiveAvgPool2d(1), - nn.Conv2d(in_channels, in_channels//8, 1), + nn.Conv2d(in_channels, in_channels // reduction, 1), nn.ReLU(), - nn.Conv2d(in_channels//8, in_channels, 1), + nn.Conv2d(in_channels // reduction, in_channels, 1), + nn.Sigmoid() + ) + + self.channel_max = nn.Sequential( + nn.AdaptiveMaxPool2d(1), + nn.Conv2d(in_channels, in_channels // reduction, 1), + nn.ReLU(), + nn.Conv2d(in_channels // reduction, in_channels, 1), nn.Sigmoid() ) def forward(self, x): - return x * self.spatial(x) + x * self.channel(x) + attn = self.spatial(x) + self.channel(x) + self.channel_max(x) + return x * attn class MinamoVisionModel(nn.Module): - def __init__(self, tile_types=32, embedding_dim=16, conv_channels=16): + def __init__(self, tile_types=32, embedding_dim=16, conv_channels=64, out_dim=128): super().__init__() # 嵌入层处理不同图块类型 self.embedding = nn.Embedding(tile_types, embedding_dim) @@ -31,41 +39,41 @@ class MinamoVisionModel(nn.Module): # 卷积部分 self.vision_conv = nn.Sequential( nn.Conv2d(embedding_dim, conv_channels, 3, padding=1), - DualAttention(conv_channels), nn.BatchNorm2d(conv_channels), + DualAttention(conv_channels, reduction=12), nn.ReLU(), + nn.MaxPool2d(2, 2), + nn.Dropout2d(0.4), nn.Conv2d(conv_channels, conv_channels*2, 3, padding=1), - DualAttention(conv_channels*2), nn.BatchNorm2d(conv_channels*2), + DualAttention(conv_channels*2, reduction=12), nn.ReLU(), + nn.MaxPool2d(2, 2), + nn.Dropout2d(0.4), nn.Conv2d(conv_channels*2, conv_channels*4, 3, padding=1), - DualAttention(conv_channels*4), nn.BatchNorm2d(conv_channels*4), + DualAttention(conv_channels*4, reduction=12), nn.ReLU(), - nn.AdaptiveAvgPool2d(1) + nn.AdaptiveMaxPool2d(1) ) - # 预测头 + # 输出为向量 self.vision_head = nn.Sequential( - nn.Linear(conv_channels*4, conv_channels*2), - nn.Dropout(0.4), - nn.Linear(conv_channels*2, 1), - nn.Sigmoid() + nn.Dropout(0.5), + nn.Linear(conv_channels*4, out_dim) ) - def forward(self, map1, map2): - e1 = self.embedding(map1).permute(0, 3, 1, 2) - e2 = self.embedding(map2).permute(0, 3, 1, 2) + def forward(self, map): + x = self.embedding(map) + # print(map.shape, x.shape) + x = x.permute(0, 3, 1, 2) + + x = self.vision_conv(x) + x = x.view(x.size(0), -1) # 展平 - v1 = self.vision_conv(e1) - v2 = self.vision_conv(e2) + vision_vec = self.vision_head(x) - v1 = v1.view(v1.size(0), -1) # 展平 - v2 = v2.view(v2.size(0), -1) # 展平 - - vision_sim = self.vision_head(torch.abs(v1 - v2)) - - return vision_sim \ No newline at end of file + return F.normalize(vision_vec, p=2, dim=-1) # 归一化 diff --git a/minamo/train.py b/minamo/train.py index 7db7294..16b8263 100644 --- a/minamo/train.py +++ b/minamo/train.py @@ -52,7 +52,7 @@ def train(): ) # 设定优化器与调度器 - optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2) + optimizer = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-2) scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-6) criterion = MinamoLoss()