def convert_map_to_graph(map):
    """Convert a 2-D tile map into a torch_geometric graph.

    Each non-wall cell (value != 1) becomes a node whose feature is its
    terrain type; 4-neighbour adjacency between non-wall cells becomes the
    edge set.

    Args:
        map: 2-D nested sequence of int tile types; value 1 marks a wall.
             Assumed rectangular and non-empty — TODO confirm with callers.

    Returns:
        torch_geometric.data.Data with
            x:          LongTensor [num_nodes] of tile types
            edge_index: LongTensor [2, num_edges]
    """
    rows, cols = len(map), len(map[0])
    node_indices = {}
    valid_nodes = []

    for r in range(rows):
        for c in range(cols):
            if map[r][c] != 1:  # exclude walls
                node_indices[(r, c)] = len(valid_nodes)
                valid_nodes.append((r, c, map[r][c]))  # (row, col, tile type)

    edge_list = []
    for (r, c, _) in valid_nodes:
        node = node_indices[(r, c)]
        # Look right and down only (avoids duplicates), but append BOTH
        # directions: GCNConv treats edge_index as directed, so an
        # undirected grid graph needs symmetric edges for correct
        # message passing.
        for nr, nc in ((r, c + 1), (r + 1, c)):
            neighbour = node_indices.get((nr, nc))
            if neighbour is not None:
                edge_list.append((node, neighbour))
                edge_list.append((neighbour, node))

    if edge_list:
        edge_index = torch.tensor(edge_list, dtype=torch.long).T
    else:
        # No adjacent walkable cells: keep the canonical [2, 0] shape so
        # downstream GCN layers still accept the graph (tensor([]).T
        # would yield the wrong shape).
        edge_index = torch.empty((2, 0), dtype=torch.long)

    node_features = torch.tensor(
        [node_type for (_, _, node_type) in valid_nodes], dtype=torch.long
    )
    return Data(x=node_features, edge_index=edge_index)
class MinamoModel(nn.Module):
    """Joint map-similarity model.

    Combines a CNN branch (visual similarity of the raw tile grids) with a
    GNN branch (topological similarity of the walkability graphs).
    """

    def __init__(self, tile_types=32, embedding_dim=16, conv_channels=16):
        super().__init__()
        # Visual-similarity sub-network (CNN over embedded tile maps).
        self.vision_model = MinamoVisionModel(tile_types, embedding_dim, conv_channels)
        # Topological-similarity sub-network (GCN over map graphs).
        self.topo_model = MinamoTopoModel(tile_types)

    def forward(self, map1, map2, graph1, graph2):
        """Score two maps.

        Returns (vision_sim, topo_sim), each shaped (batch, 1).
        vision_sim comes from the CNN head; topo_sim is the cosine
        similarity between the two graph embeddings (range [-1, 1]).
        """
        visual_score = self.vision_model(map1, map2)

        emb_a = self.topo_model(graph1)
        emb_b = self.topo_model(graph2)
        topo_score = F.cosine_similarity(emb_a, emb_b, -1).unsqueeze(-1)

        return visual_score, topo_score
class MinamoTopoModel(nn.Module):
    """Graph encoder producing a fixed-size topology embedding for a map.

    Pipeline: tile-type embedding -> two GCN layers -> global mean pool
    -> linear projection down to ``mlp_dim``.
    """

    def __init__(
        self, tile_types=32, emb_dim=16, hidden_dim=32, out_dim=16, mlp_dim=8
    ):
        super().__init__()
        # Embedding: tile id -> dense vector used as the node feature.
        self.embedding = nn.Embedding(tile_types, emb_dim)
        # Two rounds of graph convolution over the walkability graph.
        self.conv1 = GCNConv(emb_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, out_dim)
        # Fully-connected projection applied after pooling.
        self.fc = nn.Linear(out_dim, mlp_dim)

    def forward(self, graph: Data):
        """Encode ``graph`` (a PyG Data/Batch) to (batch_size, mlp_dim)."""
        h = self.embedding(graph.x)
        h = F.relu(self.conv1(h, graph.edge_index))
        h = self.conv2(h, graph.edge_index)
        # graph.batch maps each node to its graph id; pooling collapses
        # node features to one vector per graph in the batch.
        pooled = global_mean_pool(h, graph.batch)
        return self.fc(pooled)
class DualAttention(nn.Module):
    """Parallel spatial + channel attention.

    Returns the sum of the spatially-gated and channel-gated inputs.
    """

    def __init__(self, in_channels):
        super().__init__()
        # Spatial attention: 1x1 conv down to a single-channel gate map.
        self.spatial = nn.Sequential(
            nn.Conv2d(in_channels, 1, 1),
            nn.Sigmoid()
        )
        # Channel attention: squeeze-and-excitation style gating.
        self.channel = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels // 8, 1),
            nn.ReLU(),
            nn.Conv2d(in_channels // 8, in_channels, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return x * self.spatial(x) + x * self.channel(x)


class MinamoVisionModel(nn.Module):
    """CNN that scores the visual similarity of two tile maps in [0, 1]."""

    def __init__(self, tile_types=32, embedding_dim=16, conv_channels=16):
        super().__init__()
        # Embedding layer mapping tile ids to dense vectors.
        self.embedding = nn.Embedding(tile_types, embedding_dim)

        # Convolutional feature extractor (shared by both inputs).
        self.vision_conv = nn.Sequential(
            nn.Conv2d(embedding_dim, conv_channels, 3, padding=1),
            DualAttention(conv_channels),
            nn.BatchNorm2d(conv_channels),
            nn.ReLU(),

            nn.Conv2d(conv_channels, conv_channels*2, 3, padding=1),
            DualAttention(conv_channels*2),
            nn.BatchNorm2d(conv_channels*2),
            nn.ReLU(),

            nn.Conv2d(conv_channels*2, conv_channels*4, 3, padding=1),
            DualAttention(conv_channels*4),
            nn.BatchNorm2d(conv_channels*4),
            nn.ReLU(),

            nn.AdaptiveAvgPool2d(1)
        )

        # Similarity head.
        # FIX: the original stacked two Linear layers with only Dropout
        # between them — an affine composition with no nonlinearity, so
        # the hidden layer added zero capacity. Insert a ReLU so the
        # hidden layer is effective. NOTE(review): this shifts Sequential
        # indices, so old checkpoints (if any) won't load — confirm none
        # exist.
        self.vision_head = nn.Sequential(
            nn.Linear(conv_channels*4, conv_channels*2),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(conv_channels*2, 1),
            nn.Sigmoid()
        )

    def forward(self, map1, map2):
        """map1/map2: LongTensor (B, H, W) of tile ids -> (B, 1) score."""
        e1 = self.embedding(map1).permute(0, 3, 1, 2)  # to (B, C, H, W)
        e2 = self.embedding(map2).permute(0, 3, 1, 2)

        v1 = self.vision_conv(e1)
        v2 = self.vision_conv(e2)

        v1 = v1.view(v1.size(0), -1)  # flatten pooled features
        v2 = v2.view(v2.size(0), -1)

        # Score the absolute feature difference (symmetric in the inputs).
        vision_sim = self.vision_head(torch.abs(v1 - v2))

        return vision_sim
model(map1, map2, graph1, graph2) # 计算损失 loss = criterion(vision_pred, topo_pred, vision_simi, topo_simi) @@ -103,13 +105,15 @@ def train(): val_loss = 0 with torch.no_grad(): for val_batch in val_loader: - map1_val, map2_val, vision_simi_val, topo_simi_val = val_batch + map1_val, map2_val, vision_simi_val, topo_simi_val, graph1, graph2 = val_batch map1_val = map1_val.to(device) map2_val = map2_val.to(device) vision_simi_val = vision_simi_val.to(device) topo_simi_val = topo_simi_val.to(device) + graph1 = graph1.to(device) + graph2 = graph2.to(device) - vision_pred_val, topo_pred_val = model(map1_val, map2_val) + vision_pred_val, topo_pred_val = model(map1_val, map2_val, graph1, graph2) loss_val = criterion( vision_pred_val, topo_pred_val, vision_simi_val, topo_simi_val diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..f09e17c --- /dev/null +++ b/requirements.txt @@ -0,0 +1,6 @@ +torch +torchvision +torchaudio +tqdm +torch-geometric +transformers \ No newline at end of file