refactor: Minamo Model 改为 softmax 输入

This commit is contained in:
unanmed 2025-03-19 21:13:51 +08:00
parent 452df38866
commit cb9e67dff7
4 changed files with 77 additions and 103 deletions

View File

@ -1,7 +1,8 @@
import json
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from shared.graph import convert_map_to_graph
from shared.graph import convert_soft_map_to_graph
def load_data(path: str):
with open(path, 'r', encoding="utf-8") as f:
@ -22,11 +23,18 @@ class MinamoDataset(Dataset):
def __getitem__(self, idx):
item = self.data[idx]
map1_probs = F.one_hot(torch.LongTensor(item['map1']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
map2_probs = F.one_hot(torch.LongTensor(item['map2']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
graph1 = convert_soft_map_to_graph(map1_probs)
graph2 = convert_soft_map_to_graph(map2_probs)
return (
torch.LongTensor(item['map1']),
torch.LongTensor(item['map2']),
map1_probs,
map2_probs,
torch.FloatTensor([item['visionSimilarity']]),
torch.FloatTensor([item['topoSimilarity']]),
convert_map_to_graph(item['map1']),
convert_map_to_graph(item['map2'])
graph1,
graph2
)

View File

@ -9,8 +9,8 @@ class MinamoTopoModel(nn.Module):
self, tile_types=32, emb_dim=64, hidden_dim=64, out_dim=512, mlp_dim=128
):
super().__init__()
# 嵌入层
self.embedding = torch.nn.Embedding(tile_types, emb_dim)
# 传入 softmax 概率值,直接映射
self.input_proj = torch.nn.Linear(tile_types, emb_dim)
# 图卷积层
self.conv1 = GATConv(emb_dim, hidden_dim*2, heads=8, dropout=0.2)
self.conv2 = GATConv(hidden_dim*16, hidden_dim*4, heads=4)
@ -31,7 +31,7 @@ class MinamoTopoModel(nn.Module):
)
def forward(self, graph: Data):
x = self.embedding(graph.x)
x = self.input_proj(graph.x)
# identity = x
x = self.conv1(x, graph.edge_index)

View File

@ -4,30 +4,30 @@ import torch.nn.functional as F
from shared.attention import CBAM
class MinamoVisionModel(nn.Module):
def __init__(self, tile_types=32, embedding_dim=32, conv_channels=64, out_dim=128):
def __init__(self, tile_types=32, conv_ch=32, out_dim=128):
super().__init__()
# 嵌入层处理不同图块类型
self.embedding = nn.Embedding(tile_types, embedding_dim)
# 输入 softmax 概率值
self.input_conv = nn.Conv2d(tile_types, conv_ch, 3, padding=1)
# 卷积部分
self.vision_conv = nn.Sequential(
nn.Conv2d(embedding_dim, conv_channels, 3, padding=1),
nn.BatchNorm2d(conv_channels),
CBAM(conv_channels),
nn.Conv2d(conv_ch, conv_ch*2, 3, padding=1),
nn.BatchNorm2d(conv_ch*2),
CBAM(conv_ch*2),
nn.GELU(),
nn.MaxPool2d(2),
nn.Dropout2d(0.4),
nn.Conv2d(conv_channels, conv_channels*2, 3, padding=1),
nn.BatchNorm2d(conv_channels*2),
CBAM(conv_channels*2),
nn.Conv2d(conv_ch*2, conv_ch*4, 3, padding=1),
nn.BatchNorm2d(conv_ch*4),
CBAM(conv_ch*4),
nn.GELU(),
nn.MaxPool2d(2),
nn.Dropout2d(0.4),
nn.Conv2d(conv_channels*2, conv_channels*4, 3, padding=1),
nn.BatchNorm2d(conv_channels*4),
CBAM(conv_channels*4),
nn.Conv2d(conv_ch*4, conv_ch*8, 3, padding=1),
nn.BatchNorm2d(conv_ch*8),
CBAM(conv_ch*8),
nn.GELU(),
nn.AdaptiveMaxPool2d(1)
@ -36,13 +36,11 @@ class MinamoVisionModel(nn.Module):
# 输出为向量
self.vision_head = nn.Sequential(
nn.Dropout(0.4),
nn.Linear(conv_channels*4, out_dim)
nn.Linear(conv_ch*8, out_dim)
)
def forward(self, map):
x = self.embedding(map)
x = x.permute(0, 3, 1, 2)
x = self.input_conv(map)
x = self.vision_conv(x)
x = x.view(x.size(0), -1) # 展平

View File

@ -2,7 +2,37 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, Batch
from torch_geometric.utils import dense_to_sparse
def convert_soft_map_to_graph(map_probs):
"""
直接使用 Softmax 概率构建 soft 图结构
"""
C, H, W = map_probs.shape # [32, H, W]
N = H * W
device = map_probs.device
# 计算 soft 节点特征
node_features = map_probs.view(C, N).T # [N, C]
# 计算 soft 邻接边(基于 soft 权重)
edge_list = []
for r in range(H):
for c in range(W):
node = r * W + c
if c + 1 < W:
right = node + 1
edge_list.append([node, right])
if r + 1 < H:
down = node + W
edge_list.append([node, down])
edge_index = torch.tensor(edge_list).t().to(device)
# 计算 soft 边权重(基于 Softmax 概率)
soft_edge_weight = (map_probs[:, edge_index[0] // W, edge_index[0] % W] +
map_probs[:, edge_index[1] // W, edge_index[1] % W]) / 2
return Data(x=node_features, edge_index=edge_index, edge_attr=soft_edge_weight)
def convert_map_to_graph(map):
rows = len(map)
@ -31,68 +61,16 @@ def convert_map_to_graph(map):
return Data(x=node_features, edge_index=edge_index)
def soft_convert_map_to_graph(map_tensor, tau=0.5, threshold=0.1):
"""
将地图批量转换为 GNN 可用的 Graph Data使用 soft 策略
"""
B, H, W, C = map_tensor.shape
N = H * W
# 使用 Gumbel-Softmax 确保是 one-hot 形式
y = F.gumbel_softmax(map_tensor.view(B, N, C), tau=tau, hard=True) # [B, N, C]
# 计算整数索引(用于 embedding
node_features = y.argmax(dim=-1).long().unsqueeze(-1) # [B, N, 1]
# 取出墙体的 soft mask
wall_mask = y[:, :, 1] # [B, N]
adjacency_matrix = torch.zeros(B, N, N, device=map_tensor.device)
def get_index(r, c):
return r * W + c
for r in range(H):
for c in range(W):
idx = get_index(r, c)
if c + 1 < W:
right_idx = get_index(r, c + 1)
adjacency_matrix[:, idx, right_idx] = (1 - wall_mask[:, idx]) * (1 - wall_mask[:, right_idx])
if r + 1 < H:
down_idx = get_index(r + 1, c)
adjacency_matrix[:, idx, down_idx] = (1 - wall_mask[:, idx]) * (1 - wall_mask[:, down_idx])
edge_index_list, edge_weight_list = [], []
for b in range(B):
adj_bin = (adjacency_matrix[b] > threshold).float()
edge_index = torch.nonzero(adj_bin, as_tuple=False).T # [2, E]
edge_weight = adjacency_matrix[b][edge_index[0], edge_index[1]]
edge_index_list.append(edge_index)
edge_weight_list.append(edge_weight)
edge_index = torch.cat(edge_index_list, dim=1)
edge_weight = torch.cat(edge_weight_list, dim=0)
return Data(x=node_features, edge_index=edge_index, edge_attr=edge_weight)
class DynamicGraphConverter(nn.Module):
def __init__(self, map_size=13):
super().__init__()
self.map_size = map_size
self.n_nodes = map_size * map_size
# 预计算所有可能的边索引组合(包括对角线)
self.base_edge_index = self._precompute_base_edges()
def _precompute_base_edges(self):
"""预生成全连接边索引(包含所有可能邻接)"""
edge_list = []
directions = [
(0, 1), # 右
(1, 0), # 下
]
directions = [(0, 1), (1, 0)]
for r in range(self.map_size):
for c in range(self.map_size):
node = r * self.map_size + c
@ -101,51 +79,41 @@ class DynamicGraphConverter(nn.Module):
if 0 <= nr < self.map_size and 0 <= nc < self.map_size:
neighbor = nr * self.map_size + nc
edge_list.append([node, neighbor])
return torch.tensor(edge_list).t().contiguous().unique(dim=1)
def forward(self, map_probs, tau=0.5):
B, C, H, W = map_probs.shape
device = map_probs.device
self.base_edge_index = self.base_edge_index.to(device)
# 1. 节点特征离散化(保持可导)
# 1. 计算可微的节点 ID
node_logits = map_probs.view(B, C, -1).permute(0, 2, 1) # [B, N, C]
hard_nodes = F.gumbel_softmax(node_logits, tau=tau, hard=True)
node_ids = hard_nodes.argmax(dim=-1) # [B, N]
node_ids = (hard_nodes * torch.arange(C, device=device).view(1, 1, -1)).sum(dim=-1).long()
# 2. 动态边权重计算
wall_mask = (node_ids == 1).float() # 假设类别1是墙体
# 2. 计算 soft 壁障 mask
wall_mask = torch.sigmoid((node_ids - 1) * 10) # 类别 1 代表墙体soft 处理
edge_weights = self._compute_dynamic_weights(wall_mask)
# 3. 构建动态图
batch_data = []
for b in range(B):
# 动态过滤无效边(与墙体相连的边)
valid_mask = (edge_weights[b] > 0.1).squeeze(-1)
dynamic_edge_index = self.base_edge_index[:, valid_mask]
dynamic_edge_attr = edge_weights[b][valid_mask]
soft_mask = torch.sigmoid((edge_weights[b] - 0.1) * 10) # 软门控
dynamic_edge_attr = edge_weights[b] * soft_mask # 仍然保留梯度
data = Data(
x=node_ids[b],
edge_index=dynamic_edge_index,
edge_index=self.base_edge_index,
edge_attr=dynamic_edge_attr
)
batch_data.append(data)
return Batch.from_data_list(batch_data)
def _compute_dynamic_weights(self, wall_mask):
"""基于墙体存在性计算动态边权重"""
# wall_mask: [B, N]
src_nodes = self.base_edge_index[0] # [E]
dst_nodes = self.base_edge_index[1] # [E]
src_nodes = self.base_edge_index[0]
dst_nodes = self.base_edge_index[1]
# 边权重 = 1 - (源是墙 OR 目标墙)
weights = 1 - torch.logical_or(
wall_mask[:, src_nodes],
wall_mask[:, dst_nodes]
).float() # [B, E]
return weights.unsqueeze(-1) # [B, E, 1]
# 让梯度能正确回传
weights = 1 - (wall_mask[:, src_nodes] + wall_mask[:, dst_nodes]) / 2
return weights.unsqueeze(-1)