refactor: 拓扑改为图卷积结构

This commit is contained in:
unanmed 2025-03-16 22:47:38 +08:00
parent 0910bddba2
commit b43a8693ef
7 changed files with 159 additions and 112 deletions

View File

@ -1,6 +1,7 @@
import json
import torch
from torch.utils.data import Dataset
from torch_geometric.data import Data
def load_data(path: str):
with open(path, 'r', encoding="utf-8") as f:
@ -12,6 +13,33 @@ def load_data(path: str):
return data_list
def convert_map_to_graph(map):
    """Convert a 2D tile map into a torch_geometric graph.

    Every non-wall cell (value != 1) becomes a node whose feature is its
    terrain type; 4-neighbour adjacency between surviving cells becomes the
    edge set.

    Args:
        map: 2D nested sequence of integer tile types; the value 1 marks a
            wall and is excluded from the graph.

    Returns:
        Data: graph with ``x`` = (num_nodes,) long tensor of tile types and
        ``edge_index`` = (2, num_edges) long tensor. Both directions of each
        adjacency are emitted, since GCNConv treats edge_index as directed
        and an undirected grid must be represented symmetrically.
    """
    rows = len(map)
    cols = len(map[0]) if rows else 0

    # Assign a contiguous node index to every traversable cell.
    node_indices = {}
    valid_nodes = []  # (row, col, terrain type)
    for r in range(rows):
        for c in range(cols):
            if map[r][c] != 1:  # exclude walls
                node_indices[(r, c)] = len(valid_nodes)
                valid_nodes.append((r, c, map[r][c]))

    # Link each cell to its right and down neighbours; emit both directions
    # so the resulting graph is undirected.
    edge_list = []
    for (r, c, _) in valid_nodes:
        node = node_indices[(r, c)]
        for neighbor_pos in ((r, c + 1), (r + 1, c)):
            neighbor = node_indices.get(neighbor_pos)
            if neighbor is not None:
                edge_list.append((node, neighbor))
                edge_list.append((neighbor, node))

    if edge_list:
        edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
    else:
        # Guard: an all-wall or single-node map must still produce a valid
        # (2, 0) edge_index rather than a 1-D empty tensor.
        edge_index = torch.empty((2, 0), dtype=torch.long)

    node_features = torch.tensor(
        [node_type for (_, _, node_type) in valid_nodes], dtype=torch.long
    )
    return Data(x=node_features, edge_index=edge_index)
class MinamoDataset(Dataset):
def __init__(self, data_path: str):
self.data = load_data(data_path) # 自定义数据加载函数
@ -25,5 +53,7 @@ class MinamoDataset(Dataset):
torch.LongTensor(item['map1']),
torch.LongTensor(item['map2']),
torch.FloatTensor([item['visionSimilarity']]),
torch.FloatTensor([item['topoSimilarity']])
torch.FloatTensor([item['topoSimilarity']]),
convert_map_to_graph(item['map1']),
convert_map_to_graph(item['map2'])
)

View File

@ -8,6 +8,7 @@ class MinamoLoss(nn.Module):
self.mse = nn.MSELoss()
def forward(self, vis_pred, topo_pred, vis_true, topo_true):
# print(vis_pred.shape, topo_pred.shape, vis_true.shape, topo_true.shape)
# print(vis_pred[0].item(), topo_pred[0].item(), vis_true[0].item(), topo_true[0].item())
vis_loss = self.mse(vis_pred, vis_true)
topo_loss = self.mse(topo_pred, topo_true)

View File

@ -1,114 +1,20 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class DualAttention(nn.Module):
    """Combine spatial and channel attention over a feature map.

    Produces ``x * spatial(x) + x * channel(x)``: a per-pixel gate plus a
    per-channel gate, both sigmoid-bounded, applied to the same input.
    """

    def __init__(self, in_channels):
        super().__init__()
        # Spatial attention: 1x1 conv collapses channels to one gate map.
        self.spatial = nn.Sequential(
            nn.Conv2d(in_channels, 1, 1),
            nn.Sigmoid(),
        )
        # Channel attention: squeeze-and-excitation style bottleneck.
        squeeze = in_channels // 8
        self.channel = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, squeeze, 1),
            nn.ReLU(),
            nn.Conv2d(squeeze, in_channels, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        spatial_gate = self.spatial(x)
        channel_gate = self.channel(x)
        return x * spatial_gate + x * channel_gate
class DirectionalAttention(nn.Module):
    """Attention built from four directional convolution branches.

    Each branch convolves the channel-mean of the input with its own kernel
    ('h', 'v', 'd1', 'd2'); the branches are softmax-weighted by their global
    mean response and blended into a single gate applied to the input.

    Note: all four convs share the same square-kernel construction — the
    directional specialisation comes only from what each learns.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        # One single-channel conv per direction; avoid shadowing builtin `dir`.
        self.direction_convs = nn.ModuleDict({
            name: nn.Conv2d(1, 1, kernel_size, padding=kernel_size // 2,
                            padding_mode='replicate')
            for name in ['h', 'v', 'd1', 'd2']
        })

    def forward(self, x):
        # Hoist the channel mean: it is identical for all four branches
        # (the original recomputed it once per branch).
        pooled = x.mean(1, keepdim=True)
        atts = [self.direction_convs[name](pooled)
                for name in ['h', 'v', 'd1', 'd2']]
        combined = torch.stack(atts, dim=1)                     # [B, 4, 1, H, W]
        # Dynamic fusion: weight each direction by its mean activation.
        att_weights = F.softmax(combined.mean([3, 4]), dim=1)   # [B, 4, 1]
        gate = (combined * att_weights.unsqueeze(-1).unsqueeze(-1)).sum(1)
        return x * gate
from .vision import MinamoVisionModel
from .topo import MinamoTopoModel
class MinamoModel(nn.Module):
    """Two-branch map-similarity model.

    Delegates visual similarity to a convolutional `MinamoVisionModel` and
    topological similarity to a graph-convolutional `MinamoTopoModel`,
    comparing the latter's embeddings with cosine similarity.

    NOTE(review): this span of the source was a rendered diff with removed
    and added lines interleaved (two `__init__`/`forward` definitions); this
    body is the reconstructed post-commit version, matching the new call
    sites `model(map1, map2, graph1, graph2)` in the training loop.
    """

    def __init__(self, tile_types=32, embedding_dim=16, conv_channels=16):
        super().__init__()
        # Visual-similarity branch (CNN over embedded tile maps).
        self.vision_model = MinamoVisionModel(tile_types, embedding_dim, conv_channels)
        # Topological-similarity branch (GCN over the map graphs).
        self.topo_model = MinamoTopoModel(tile_types)

    def forward(self, map1, map2, graph1, graph2):
        vision_sim = self.vision_model(map1, map2)
        topo_feat1 = self.topo_model(graph1)
        topo_feat2 = self.topo_model(graph2)
        # NOTE(review): cosine similarity ranges over [-1, 1] while the
        # vision head is sigmoid-bounded to (0, 1) — confirm the
        # topoSimilarity labels share the cosine range.
        topo_sim = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
        return vision_sim, topo_sim

29
minamo/model/topo.py Normal file
View File

@ -0,0 +1,29 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.data import Data
class MinamoTopoModel(nn.Module):
    """Graph-convolutional encoder for map topology.

    Embeds tile types, runs two GCN layers over the map graph, mean-pools
    node features per graph, and projects to a compact embedding of size
    ``mlp_dim``.
    """

    def __init__(
        self, tile_types=32, emb_dim=16, hidden_dim=32, out_dim=16, mlp_dim=8
    ):
        super().__init__()
        # Tile-type embedding table.
        self.embedding = nn.Embedding(tile_types, emb_dim)
        # Two-layer graph convolution stack.
        self.conv1 = GCNConv(emb_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, out_dim)
        # Final projection down to the comparison space.
        self.fc = nn.Linear(out_dim, mlp_dim)

    def forward(self, graph: Data):
        feats = self.embedding(graph.x)
        feats = F.relu(self.conv1(feats, graph.edge_index))
        # No activation after the second conv, matching the original design.
        feats = self.conv2(feats, graph.edge_index)
        pooled = global_mean_pool(feats, graph.batch)
        return self.fc(pooled)  # (batch_size, mlp_dim)

71
minamo/model/vision.py Normal file
View File

@ -0,0 +1,71 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class DualAttention(nn.Module):
    """Apply parallel spatial and channel attention gates to a feature map.

    Output is ``x * spatial(x) + x * channel(x)``, i.e. the input modulated
    by a per-pixel sigmoid gate plus the input modulated by a per-channel
    sigmoid gate.
    """

    def __init__(self, in_channels):
        super().__init__()
        # Spatial attention: collapse channels to a single gate map.
        self.spatial = nn.Sequential(
            nn.Conv2d(in_channels, 1, 1),
            nn.Sigmoid(),
        )
        # Channel attention: squeeze-and-excitation style bottleneck.
        reduced = in_channels // 8
        self.channel = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, reduced, 1),
            nn.ReLU(),
            nn.Conv2d(reduced, in_channels, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        gated_by_position = x * self.spatial(x)
        gated_by_channel = x * self.channel(x)
        return gated_by_position + gated_by_channel
class MinamoVisionModel(nn.Module):
    """Convolutional branch scoring the visual similarity of two tile maps.

    Both maps are embedded, passed through a shared attention-augmented CNN,
    globally pooled, and compared via the absolute difference of their
    descriptors; a small MLP head maps that difference to a (0, 1) score.
    """

    def __init__(self, tile_types=32, embedding_dim=16, conv_channels=16):
        super().__init__()
        # Embedding turns integer tile ids into dense per-cell features.
        self.embedding = nn.Embedding(tile_types, embedding_dim)
        # Three conv stages (each: conv -> dual attention -> BN -> ReLU),
        # collapsed to one descriptor by global average pooling.
        self.vision_conv = nn.Sequential(
            nn.Conv2d(embedding_dim, conv_channels, 3, padding=1),
            DualAttention(conv_channels),
            nn.BatchNorm2d(conv_channels),
            nn.ReLU(),
            nn.Conv2d(conv_channels, conv_channels * 2, 3, padding=1),
            DualAttention(conv_channels * 2),
            nn.BatchNorm2d(conv_channels * 2),
            nn.ReLU(),
            nn.Conv2d(conv_channels * 2, conv_channels * 4, 3, padding=1),
            DualAttention(conv_channels * 4),
            nn.BatchNorm2d(conv_channels * 4),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
        )
        # Similarity head over the absolute descriptor difference.
        self.vision_head = nn.Sequential(
            nn.Linear(conv_channels * 4, conv_channels * 2),
            nn.Dropout(0.4),
            nn.Linear(conv_channels * 2, 1),
            nn.Sigmoid(),
        )

    def forward(self, map1, map2):
        # Embed and move channels first: (B, H, W, E) -> (B, E, H, W).
        feat_a = self.vision_conv(self.embedding(map1).permute(0, 3, 1, 2))
        feat_b = self.vision_conv(self.embedding(map2).permute(0, 3, 1, 2))
        # Flatten the pooled (B, C, 1, 1) descriptors to (B, C).
        feat_a = feat_a.flatten(1)
        feat_b = feat_b.flatten(1)
        return self.vision_head((feat_a - feat_b).abs())

View File

@ -2,7 +2,7 @@ import os
from datetime import datetime
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torch_geometric.loader import DataLoader
from tqdm import tqdm
from .model.model import MinamoModel
from .model.loss import MinamoLoss
@ -63,17 +63,19 @@ def train():
for batch in dataloader:
# 数据迁移到设备
map1, map2, vision_simi, topo_simi = batch
map1, map2, vision_simi, topo_simi, graph1, graph2 = batch
map1 = map1.to(device) # 转为 [B, C, H, W]
map2 = map2.to(device)
topo_simi = topo_simi.to(device)
vision_simi = vision_simi.to(device)
graph1 = graph1.to(device)
graph2 = graph2.to(device)
# print(map1.shape, map2.shape)
# 前向传播
optimizer.zero_grad()
vision_pred, topo_pred = model(map1, map2)
vision_pred, topo_pred = model(map1, map2, graph1, graph2)
# 计算损失
loss = criterion(vision_pred, topo_pred, vision_simi, topo_simi)
@ -103,13 +105,15 @@ def train():
val_loss = 0
with torch.no_grad():
for val_batch in val_loader:
map1_val, map2_val, vision_simi_val, topo_simi_val = val_batch
map1_val, map2_val, vision_simi_val, topo_simi_val, graph1, graph2 = val_batch
map1_val = map1_val.to(device)
map2_val = map2_val.to(device)
vision_simi_val = vision_simi_val.to(device)
topo_simi_val = topo_simi_val.to(device)
graph1 = graph1.to(device)
graph2 = graph2.to(device)
vision_pred_val, topo_pred_val = model(map1_val, map2_val)
vision_pred_val, topo_pred_val = model(map1_val, map2_val, graph1, graph2)
loss_val = criterion(
vision_pred_val, topo_pred_val,
vision_simi_val, topo_simi_val

6
requirements.txt Normal file
View File

@ -0,0 +1,6 @@
torch
torchvision
torchaudio
tqdm
torch-geometric
transformers