diff --git a/minamo/dataset.py b/minamo/dataset.py index afd253c..0c55387 100644 --- a/minamo/dataset.py +++ b/minamo/dataset.py @@ -1,7 +1,8 @@ import json import torch +import torch.nn.functional as F from torch.utils.data import Dataset -from shared.graph import convert_map_to_graph +from shared.graph import convert_soft_map_to_graph def load_data(path: str): with open(path, 'r', encoding="utf-8") as f: @@ -22,11 +23,18 @@ class MinamoDataset(Dataset): def __getitem__(self, idx): item = self.data[idx] + + map1_probs = F.one_hot(torch.LongTensor(item['map1']), num_classes=32).permute(2, 0, 1).float() # [32, H, W] + map2_probs = F.one_hot(torch.LongTensor(item['map2']), num_classes=32).permute(2, 0, 1).float() # [32, H, W] + + graph1 = convert_soft_map_to_graph(map1_probs) + graph2 = convert_soft_map_to_graph(map2_probs) + return ( - torch.LongTensor(item['map1']), - torch.LongTensor(item['map2']), + map1_probs, + map2_probs, torch.FloatTensor([item['visionSimilarity']]), torch.FloatTensor([item['topoSimilarity']]), - convert_map_to_graph(item['map1']), - convert_map_to_graph(item['map2']) + graph1, + graph2 ) diff --git a/minamo/model/topo.py b/minamo/model/topo.py index d0910e5..9a4a00e 100644 --- a/minamo/model/topo.py +++ b/minamo/model/topo.py @@ -9,8 +9,8 @@ class MinamoTopoModel(nn.Module): self, tile_types=32, emb_dim=64, hidden_dim=64, out_dim=512, mlp_dim=128 ): super().__init__() - # 嵌入层 - self.embedding = torch.nn.Embedding(tile_types, emb_dim) + # 传入 softmax 概率值,直接映射 + self.input_proj = torch.nn.Linear(tile_types, emb_dim) # 图卷积层 self.conv1 = GATConv(emb_dim, hidden_dim*2, heads=8, dropout=0.2) self.conv2 = GATConv(hidden_dim*16, hidden_dim*4, heads=4) @@ -31,7 +31,7 @@ class MinamoTopoModel(nn.Module): ) def forward(self, graph: Data): - x = self.embedding(graph.x) + x = self.input_proj(graph.x) # identity = x x = self.conv1(x, graph.edge_index) diff --git a/minamo/model/vision.py b/minamo/model/vision.py index 4f6c8bb..fc39a16 100644 --- 
a/minamo/model/vision.py +++ b/minamo/model/vision.py @@ -4,30 +4,30 @@ import torch.nn.functional as F from shared.attention import CBAM class MinamoVisionModel(nn.Module): - def __init__(self, tile_types=32, embedding_dim=32, conv_channels=64, out_dim=128): + def __init__(self, tile_types=32, conv_ch=32, out_dim=128): super().__init__() - # 嵌入层处理不同图块类型 - self.embedding = nn.Embedding(tile_types, embedding_dim) + # 输入 softmax 概率值 + self.input_conv = nn.Conv2d(tile_types, conv_ch, 3, padding=1) # 卷积部分 self.vision_conv = nn.Sequential( - nn.Conv2d(embedding_dim, conv_channels, 3, padding=1), - nn.BatchNorm2d(conv_channels), - CBAM(conv_channels), + nn.Conv2d(conv_ch, conv_ch*2, 3, padding=1), + nn.BatchNorm2d(conv_ch*2), + CBAM(conv_ch*2), nn.GELU(), nn.MaxPool2d(2), nn.Dropout2d(0.4), - nn.Conv2d(conv_channels, conv_channels*2, 3, padding=1), - nn.BatchNorm2d(conv_channels*2), - CBAM(conv_channels*2), + nn.Conv2d(conv_ch*2, conv_ch*4, 3, padding=1), + nn.BatchNorm2d(conv_ch*4), + CBAM(conv_ch*4), nn.GELU(), nn.MaxPool2d(2), nn.Dropout2d(0.4), - nn.Conv2d(conv_channels*2, conv_channels*4, 3, padding=1), - nn.BatchNorm2d(conv_channels*4), - CBAM(conv_channels*4), + nn.Conv2d(conv_ch*4, conv_ch*8, 3, padding=1), + nn.BatchNorm2d(conv_ch*8), + CBAM(conv_ch*8), nn.GELU(), nn.AdaptiveMaxPool2d(1) @@ -36,13 +36,11 @@ class MinamoVisionModel(nn.Module): # 输出为向量 self.vision_head = nn.Sequential( nn.Dropout(0.4), - nn.Linear(conv_channels*4, out_dim) + nn.Linear(conv_ch*8, out_dim) ) def forward(self, map): - x = self.embedding(map) - x = x.permute(0, 3, 1, 2) - + x = self.input_conv(map) x = self.vision_conv(x) x = x.view(x.size(0), -1) # 展平 diff --git a/shared/graph.py b/shared/graph.py index 6dc9e3e..17893b6 100644 --- a/shared/graph.py +++ b/shared/graph.py @@ -2,7 +2,37 @@ import torch import torch.nn as nn import torch.nn.functional as F from torch_geometric.data import Data, Batch -from torch_geometric.utils import dense_to_sparse + +def 
convert_soft_map_to_graph(map_probs): + """ + 直接使用 Softmax 概率构建 soft 图结构 + """ + C, H, W = map_probs.shape # [32, H, W] + N = H * W + device = map_probs.device + + # 计算 soft 节点特征 + node_features = map_probs.view(C, N).T # [N, C] + + # 计算 soft 邻接边(基于 soft 权重) + edge_list = [] + for r in range(H): + for c in range(W): + node = r * W + c + if c + 1 < W: + right = node + 1 + edge_list.append([node, right]) + if r + 1 < H: + down = node + W + edge_list.append([node, down]) + + edge_index = torch.tensor(edge_list).t().to(device) + + # soft edge features, transposed to PyG's expected [num_edges, num_features] + soft_edge_weight = (map_probs[:, edge_index[0] // W, edge_index[0] % W] + + map_probs[:, edge_index[1] // W, edge_index[1] % W]).t() / 2 + + return Data(x=node_features, edge_index=edge_index, edge_attr=soft_edge_weight) def convert_map_to_graph(map): rows = len(map) @@ -31,68 +61,16 @@ def convert_map_to_graph(map): return Data(x=node_features, edge_index=edge_index) -def soft_convert_map_to_graph(map_tensor, tau=0.5, threshold=0.1): - """ - 将地图批量转换为 GNN 可用的 Graph Data,使用 soft 策略 - """ - B, H, W, C = map_tensor.shape - N = H * W - - # 使用 Gumbel-Softmax 确保是 one-hot 形式 - y = F.gumbel_softmax(map_tensor.view(B, N, C), tau=tau, hard=True) # [B, N, C] - - # 计算整数索引(用于 embedding) - node_features = y.argmax(dim=-1).long().unsqueeze(-1) # [B, N, 1] - - # 取出墙体的 soft mask - wall_mask = y[:, :, 1] # [B, N] - - adjacency_matrix = torch.zeros(B, N, N, device=map_tensor.device) - - def get_index(r, c): - return r * W + c - - for r in range(H): - for c in range(W): - idx = get_index(r, c) - if c + 1 < W: - right_idx = get_index(r, c + 1) - adjacency_matrix[:, idx, right_idx] = (1 - wall_mask[:, idx]) * (1 - wall_mask[:, right_idx]) - if r + 1 < H: - down_idx = get_index(r + 1, c) - adjacency_matrix[:, idx, down_idx] = (1 - wall_mask[:, idx]) * (1 - wall_mask[:, down_idx]) - - edge_index_list, edge_weight_list = [], [] - for b in range(B): - adj_bin = (adjacency_matrix[b] > threshold).float() - edge_index = torch.nonzero(adj_bin,
as_tuple=False).T # [2, E] - edge_weight = adjacency_matrix[b][edge_index[0], edge_index[1]] - - edge_index_list.append(edge_index) - edge_weight_list.append(edge_weight) - - edge_index = torch.cat(edge_index_list, dim=1) - edge_weight = torch.cat(edge_weight_list, dim=0) - - return Data(x=node_features, edge_index=edge_index, edge_attr=edge_weight) - class DynamicGraphConverter(nn.Module): def __init__(self, map_size=13): super().__init__() self.map_size = map_size self.n_nodes = map_size * map_size - - # 预计算所有可能的边索引组合(包括对角线) self.base_edge_index = self._precompute_base_edges() - + def _precompute_base_edges(self): - """预生成全连接边索引(包含所有可能邻接)""" edge_list = [] - directions = [ - (0, 1), # 右 - (1, 0), # 下 - ] - + directions = [(0, 1), (1, 0)] for r in range(self.map_size): for c in range(self.map_size): node = r * self.map_size + c @@ -101,51 +79,42 @@ class DynamicGraphConverter(nn.Module): if 0 <= nr < self.map_size and 0 <= nc < self.map_size: neighbor = nr * self.map_size + nc edge_list.append([node, neighbor]) - return torch.tensor(edge_list).t().contiguous().unique(dim=1) def forward(self, map_probs, tau=0.5): B, C, H, W = map_probs.shape device = map_probs.device self.base_edge_index = self.base_edge_index.to(device) - - # 1. 节点特征离散化(保持可导) + + # 1. 计算可微的节点 ID node_logits = map_probs.view(B, C, -1).permute(0, 2, 1) # [B, N, C] hard_nodes = F.gumbel_softmax(node_logits, tau=tau, hard=True) - node_ids = hard_nodes.argmax(dim=-1) # [B, N] + node_ids = (hard_nodes * torch.arange(C, device=device).view(1, 1, -1)).sum(dim=-1).long() - # 2. 动态边权重计算 - wall_mask = (node_ids == 1).float() # 假设类别1是墙体 + # 2. 计算 soft 壁障 mask + wall_mask = hard_nodes[..., 1] # straight-through prob. of wall class 1 (sigmoid over integer ids marked every class >= 2 as wall) edge_weights = self._compute_dynamic_weights(wall_mask) # 3.
构建动态图 batch_data = [] for b in range(B): - # 动态过滤无效边(与墙体相连的边) - valid_mask = (edge_weights[b] > 0.1).squeeze(-1) - dynamic_edge_index = self.base_edge_index[:, valid_mask] - dynamic_edge_attr = edge_weights[b][valid_mask] + soft_mask = torch.sigmoid((edge_weights[b] - 0.1) * 10) # 软门控 + dynamic_edge_attr = edge_weights[b] * soft_mask # 仍然保留梯度 data = Data( - x=node_ids[b], + x=hard_nodes[b], # [N, C] probabilities to match the new Linear input_proj; integer node_ids carry no gradient - edge_index=dynamic_edge_index, + edge_index=self.base_edge_index, edge_attr=dynamic_edge_attr ) batch_data.append(data) - + return Batch.from_data_list(batch_data) def _compute_dynamic_weights(self, wall_mask): - """基于墙体存在性计算动态边权重""" - # wall_mask: [B, N] - src_nodes = self.base_edge_index[0] # [E] - dst_nodes = self.base_edge_index[1] # [E] + src_nodes = self.base_edge_index[0] + dst_nodes = self.base_edge_index[1] - # 边权重 = 1 - (源是墙 OR 目标墙) - weights = 1 - torch.logical_or( - wall_mask[:, src_nodes], - wall_mask[:, dst_nodes] - ).float() # [B, E] - - return weights.unsqueeze(-1) # [B, E, 1] + # 让梯度能正确回传 + weights = 1 - (wall_mask[:, src_nodes] + wall_mask[:, dst_nodes]) / 2 + return weights.unsqueeze(-1)