mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-21 02:11:13 +08:00
refactor: Minamo Model 改为 softmax 输入
This commit is contained in:
parent
452df38866
commit
cb9e67dff7
@ -1,7 +1,8 @@
|
|||||||
import json
|
import json
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from shared.graph import convert_map_to_graph
|
from shared.graph import convert_soft_map_to_graph
|
||||||
|
|
||||||
def load_data(path: str):
|
def load_data(path: str):
|
||||||
with open(path, 'r', encoding="utf-8") as f:
|
with open(path, 'r', encoding="utf-8") as f:
|
||||||
@ -22,11 +23,18 @@ class MinamoDataset(Dataset):
|
|||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
item = self.data[idx]
|
item = self.data[idx]
|
||||||
|
|
||||||
|
map1_probs = F.one_hot(torch.LongTensor(item['map1']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
|
||||||
|
map2_probs = F.one_hot(torch.LongTensor(item['map2']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
|
||||||
|
|
||||||
|
graph1 = convert_soft_map_to_graph(map1_probs)
|
||||||
|
graph2 = convert_soft_map_to_graph(map2_probs)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
torch.LongTensor(item['map1']),
|
map1_probs,
|
||||||
torch.LongTensor(item['map2']),
|
map2_probs,
|
||||||
torch.FloatTensor([item['visionSimilarity']]),
|
torch.FloatTensor([item['visionSimilarity']]),
|
||||||
torch.FloatTensor([item['topoSimilarity']]),
|
torch.FloatTensor([item['topoSimilarity']]),
|
||||||
convert_map_to_graph(item['map1']),
|
graph1,
|
||||||
convert_map_to_graph(item['map2'])
|
graph2
|
||||||
)
|
)
|
||||||
|
|||||||
@ -9,8 +9,8 @@ class MinamoTopoModel(nn.Module):
|
|||||||
self, tile_types=32, emb_dim=64, hidden_dim=64, out_dim=512, mlp_dim=128
|
self, tile_types=32, emb_dim=64, hidden_dim=64, out_dim=512, mlp_dim=128
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# 嵌入层
|
# 传入 softmax 概率值,直接映射
|
||||||
self.embedding = torch.nn.Embedding(tile_types, emb_dim)
|
self.input_proj = torch.nn.Linear(tile_types, emb_dim)
|
||||||
# 图卷积层
|
# 图卷积层
|
||||||
self.conv1 = GATConv(emb_dim, hidden_dim*2, heads=8, dropout=0.2)
|
self.conv1 = GATConv(emb_dim, hidden_dim*2, heads=8, dropout=0.2)
|
||||||
self.conv2 = GATConv(hidden_dim*16, hidden_dim*4, heads=4)
|
self.conv2 = GATConv(hidden_dim*16, hidden_dim*4, heads=4)
|
||||||
@ -31,7 +31,7 @@ class MinamoTopoModel(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, graph: Data):
|
def forward(self, graph: Data):
|
||||||
x = self.embedding(graph.x)
|
x = self.input_proj(graph.x)
|
||||||
# identity = x
|
# identity = x
|
||||||
|
|
||||||
x = self.conv1(x, graph.edge_index)
|
x = self.conv1(x, graph.edge_index)
|
||||||
|
|||||||
@ -4,30 +4,30 @@ import torch.nn.functional as F
|
|||||||
from shared.attention import CBAM
|
from shared.attention import CBAM
|
||||||
|
|
||||||
class MinamoVisionModel(nn.Module):
|
class MinamoVisionModel(nn.Module):
|
||||||
def __init__(self, tile_types=32, embedding_dim=32, conv_channels=64, out_dim=128):
|
def __init__(self, tile_types=32, conv_ch=32, out_dim=128):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# 嵌入层处理不同图块类型
|
# 输入 softmax 概率值
|
||||||
self.embedding = nn.Embedding(tile_types, embedding_dim)
|
self.input_conv = nn.Conv2d(tile_types, conv_ch, 3, padding=1)
|
||||||
|
|
||||||
# 卷积部分
|
# 卷积部分
|
||||||
self.vision_conv = nn.Sequential(
|
self.vision_conv = nn.Sequential(
|
||||||
nn.Conv2d(embedding_dim, conv_channels, 3, padding=1),
|
nn.Conv2d(conv_ch, conv_ch*2, 3, padding=1),
|
||||||
nn.BatchNorm2d(conv_channels),
|
nn.BatchNorm2d(conv_ch*2),
|
||||||
CBAM(conv_channels),
|
CBAM(conv_ch*2),
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
nn.MaxPool2d(2),
|
nn.MaxPool2d(2),
|
||||||
nn.Dropout2d(0.4),
|
nn.Dropout2d(0.4),
|
||||||
|
|
||||||
nn.Conv2d(conv_channels, conv_channels*2, 3, padding=1),
|
nn.Conv2d(conv_ch*2, conv_ch*4, 3, padding=1),
|
||||||
nn.BatchNorm2d(conv_channels*2),
|
nn.BatchNorm2d(conv_ch*4),
|
||||||
CBAM(conv_channels*2),
|
CBAM(conv_ch*4),
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
nn.MaxPool2d(2),
|
nn.MaxPool2d(2),
|
||||||
nn.Dropout2d(0.4),
|
nn.Dropout2d(0.4),
|
||||||
|
|
||||||
nn.Conv2d(conv_channels*2, conv_channels*4, 3, padding=1),
|
nn.Conv2d(conv_ch*4, conv_ch*8, 3, padding=1),
|
||||||
nn.BatchNorm2d(conv_channels*4),
|
nn.BatchNorm2d(conv_ch*8),
|
||||||
CBAM(conv_channels*4),
|
CBAM(conv_ch*8),
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
|
|
||||||
nn.AdaptiveMaxPool2d(1)
|
nn.AdaptiveMaxPool2d(1)
|
||||||
@ -36,13 +36,11 @@ class MinamoVisionModel(nn.Module):
|
|||||||
# 输出为向量
|
# 输出为向量
|
||||||
self.vision_head = nn.Sequential(
|
self.vision_head = nn.Sequential(
|
||||||
nn.Dropout(0.4),
|
nn.Dropout(0.4),
|
||||||
nn.Linear(conv_channels*4, out_dim)
|
nn.Linear(conv_ch*8, out_dim)
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, map):
|
def forward(self, map):
|
||||||
x = self.embedding(map)
|
x = self.input_conv(map)
|
||||||
x = x.permute(0, 3, 1, 2)
|
|
||||||
|
|
||||||
x = self.vision_conv(x)
|
x = self.vision_conv(x)
|
||||||
x = x.view(x.size(0), -1) # 展平
|
x = x.view(x.size(0), -1) # 展平
|
||||||
|
|
||||||
|
|||||||
120
shared/graph.py
120
shared/graph.py
@ -2,7 +2,37 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch_geometric.data import Data, Batch
|
from torch_geometric.data import Data, Batch
|
||||||
from torch_geometric.utils import dense_to_sparse
|
|
||||||
|
def convert_soft_map_to_graph(map_probs):
|
||||||
|
"""
|
||||||
|
直接使用 Softmax 概率构建 soft 图结构
|
||||||
|
"""
|
||||||
|
C, H, W = map_probs.shape # [32, H, W]
|
||||||
|
N = H * W
|
||||||
|
device = map_probs.device
|
||||||
|
|
||||||
|
# 计算 soft 节点特征
|
||||||
|
node_features = map_probs.view(C, N).T # [N, C]
|
||||||
|
|
||||||
|
# 计算 soft 邻接边(基于 soft 权重)
|
||||||
|
edge_list = []
|
||||||
|
for r in range(H):
|
||||||
|
for c in range(W):
|
||||||
|
node = r * W + c
|
||||||
|
if c + 1 < W:
|
||||||
|
right = node + 1
|
||||||
|
edge_list.append([node, right])
|
||||||
|
if r + 1 < H:
|
||||||
|
down = node + W
|
||||||
|
edge_list.append([node, down])
|
||||||
|
|
||||||
|
edge_index = torch.tensor(edge_list).t().to(device)
|
||||||
|
|
||||||
|
# 计算 soft 边权重(基于 Softmax 概率)
|
||||||
|
soft_edge_weight = (map_probs[:, edge_index[0] // W, edge_index[0] % W] +
|
||||||
|
map_probs[:, edge_index[1] // W, edge_index[1] % W]) / 2
|
||||||
|
|
||||||
|
return Data(x=node_features, edge_index=edge_index, edge_attr=soft_edge_weight)
|
||||||
|
|
||||||
def convert_map_to_graph(map):
|
def convert_map_to_graph(map):
|
||||||
rows = len(map)
|
rows = len(map)
|
||||||
@ -31,68 +61,16 @@ def convert_map_to_graph(map):
|
|||||||
|
|
||||||
return Data(x=node_features, edge_index=edge_index)
|
return Data(x=node_features, edge_index=edge_index)
|
||||||
|
|
||||||
def soft_convert_map_to_graph(map_tensor, tau=0.5, threshold=0.1):
|
|
||||||
"""
|
|
||||||
将地图批量转换为 GNN 可用的 Graph Data,使用 soft 策略
|
|
||||||
"""
|
|
||||||
B, H, W, C = map_tensor.shape
|
|
||||||
N = H * W
|
|
||||||
|
|
||||||
# 使用 Gumbel-Softmax 确保是 one-hot 形式
|
|
||||||
y = F.gumbel_softmax(map_tensor.view(B, N, C), tau=tau, hard=True) # [B, N, C]
|
|
||||||
|
|
||||||
# 计算整数索引(用于 embedding)
|
|
||||||
node_features = y.argmax(dim=-1).long().unsqueeze(-1) # [B, N, 1]
|
|
||||||
|
|
||||||
# 取出墙体的 soft mask
|
|
||||||
wall_mask = y[:, :, 1] # [B, N]
|
|
||||||
|
|
||||||
adjacency_matrix = torch.zeros(B, N, N, device=map_tensor.device)
|
|
||||||
|
|
||||||
def get_index(r, c):
|
|
||||||
return r * W + c
|
|
||||||
|
|
||||||
for r in range(H):
|
|
||||||
for c in range(W):
|
|
||||||
idx = get_index(r, c)
|
|
||||||
if c + 1 < W:
|
|
||||||
right_idx = get_index(r, c + 1)
|
|
||||||
adjacency_matrix[:, idx, right_idx] = (1 - wall_mask[:, idx]) * (1 - wall_mask[:, right_idx])
|
|
||||||
if r + 1 < H:
|
|
||||||
down_idx = get_index(r + 1, c)
|
|
||||||
adjacency_matrix[:, idx, down_idx] = (1 - wall_mask[:, idx]) * (1 - wall_mask[:, down_idx])
|
|
||||||
|
|
||||||
edge_index_list, edge_weight_list = [], []
|
|
||||||
for b in range(B):
|
|
||||||
adj_bin = (adjacency_matrix[b] > threshold).float()
|
|
||||||
edge_index = torch.nonzero(adj_bin, as_tuple=False).T # [2, E]
|
|
||||||
edge_weight = adjacency_matrix[b][edge_index[0], edge_index[1]]
|
|
||||||
|
|
||||||
edge_index_list.append(edge_index)
|
|
||||||
edge_weight_list.append(edge_weight)
|
|
||||||
|
|
||||||
edge_index = torch.cat(edge_index_list, dim=1)
|
|
||||||
edge_weight = torch.cat(edge_weight_list, dim=0)
|
|
||||||
|
|
||||||
return Data(x=node_features, edge_index=edge_index, edge_attr=edge_weight)
|
|
||||||
|
|
||||||
class DynamicGraphConverter(nn.Module):
|
class DynamicGraphConverter(nn.Module):
|
||||||
def __init__(self, map_size=13):
|
def __init__(self, map_size=13):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.map_size = map_size
|
self.map_size = map_size
|
||||||
self.n_nodes = map_size * map_size
|
self.n_nodes = map_size * map_size
|
||||||
|
|
||||||
# 预计算所有可能的边索引组合(包括对角线)
|
|
||||||
self.base_edge_index = self._precompute_base_edges()
|
self.base_edge_index = self._precompute_base_edges()
|
||||||
|
|
||||||
def _precompute_base_edges(self):
|
def _precompute_base_edges(self):
|
||||||
"""预生成全连接边索引(包含所有可能邻接)"""
|
|
||||||
edge_list = []
|
edge_list = []
|
||||||
directions = [
|
directions = [(0, 1), (1, 0)]
|
||||||
(0, 1), # 右
|
|
||||||
(1, 0), # 下
|
|
||||||
]
|
|
||||||
|
|
||||||
for r in range(self.map_size):
|
for r in range(self.map_size):
|
||||||
for c in range(self.map_size):
|
for c in range(self.map_size):
|
||||||
node = r * self.map_size + c
|
node = r * self.map_size + c
|
||||||
@ -101,35 +79,31 @@ class DynamicGraphConverter(nn.Module):
|
|||||||
if 0 <= nr < self.map_size and 0 <= nc < self.map_size:
|
if 0 <= nr < self.map_size and 0 <= nc < self.map_size:
|
||||||
neighbor = nr * self.map_size + nc
|
neighbor = nr * self.map_size + nc
|
||||||
edge_list.append([node, neighbor])
|
edge_list.append([node, neighbor])
|
||||||
|
|
||||||
return torch.tensor(edge_list).t().contiguous().unique(dim=1)
|
return torch.tensor(edge_list).t().contiguous().unique(dim=1)
|
||||||
|
|
||||||
def forward(self, map_probs, tau=0.5):
|
def forward(self, map_probs, tau=0.5):
|
||||||
B, C, H, W = map_probs.shape
|
B, C, H, W = map_probs.shape
|
||||||
device = map_probs.device
|
device = map_probs.device
|
||||||
|
|
||||||
self.base_edge_index = self.base_edge_index.to(device)
|
self.base_edge_index = self.base_edge_index.to(device)
|
||||||
|
|
||||||
# 1. 节点特征离散化(保持可导)
|
# 1. 计算可微的节点 ID
|
||||||
node_logits = map_probs.view(B, C, -1).permute(0, 2, 1) # [B, N, C]
|
node_logits = map_probs.view(B, C, -1).permute(0, 2, 1) # [B, N, C]
|
||||||
hard_nodes = F.gumbel_softmax(node_logits, tau=tau, hard=True)
|
hard_nodes = F.gumbel_softmax(node_logits, tau=tau, hard=True)
|
||||||
node_ids = hard_nodes.argmax(dim=-1) # [B, N]
|
node_ids = (hard_nodes * torch.arange(C, device=device).view(1, 1, -1)).sum(dim=-1).long()
|
||||||
|
|
||||||
# 2. 动态边权重计算
|
# 2. 计算 soft 壁障 mask
|
||||||
wall_mask = (node_ids == 1).float() # 假设类别1是墙体
|
wall_mask = torch.sigmoid((node_ids - 1) * 10) # 类别 1 代表墙体,soft 处理
|
||||||
edge_weights = self._compute_dynamic_weights(wall_mask)
|
edge_weights = self._compute_dynamic_weights(wall_mask)
|
||||||
|
|
||||||
# 3. 构建动态图
|
# 3. 构建动态图
|
||||||
batch_data = []
|
batch_data = []
|
||||||
for b in range(B):
|
for b in range(B):
|
||||||
# 动态过滤无效边(与墙体相连的边)
|
soft_mask = torch.sigmoid((edge_weights[b] - 0.1) * 10) # 软门控
|
||||||
valid_mask = (edge_weights[b] > 0.1).squeeze(-1)
|
dynamic_edge_attr = edge_weights[b] * soft_mask # 仍然保留梯度
|
||||||
dynamic_edge_index = self.base_edge_index[:, valid_mask]
|
|
||||||
dynamic_edge_attr = edge_weights[b][valid_mask]
|
|
||||||
|
|
||||||
data = Data(
|
data = Data(
|
||||||
x=node_ids[b],
|
x=node_ids[b],
|
||||||
edge_index=dynamic_edge_index,
|
edge_index=self.base_edge_index,
|
||||||
edge_attr=dynamic_edge_attr
|
edge_attr=dynamic_edge_attr
|
||||||
)
|
)
|
||||||
batch_data.append(data)
|
batch_data.append(data)
|
||||||
@ -137,15 +111,9 @@ class DynamicGraphConverter(nn.Module):
|
|||||||
return Batch.from_data_list(batch_data)
|
return Batch.from_data_list(batch_data)
|
||||||
|
|
||||||
def _compute_dynamic_weights(self, wall_mask):
|
def _compute_dynamic_weights(self, wall_mask):
|
||||||
"""基于墙体存在性计算动态边权重"""
|
src_nodes = self.base_edge_index[0]
|
||||||
# wall_mask: [B, N]
|
dst_nodes = self.base_edge_index[1]
|
||||||
src_nodes = self.base_edge_index[0] # [E]
|
|
||||||
dst_nodes = self.base_edge_index[1] # [E]
|
|
||||||
|
|
||||||
# 边权重 = 1 - (源是墙 OR 目标墙)
|
# 让梯度能正确回传
|
||||||
weights = 1 - torch.logical_or(
|
weights = 1 - (wall_mask[:, src_nodes] + wall_mask[:, dst_nodes]) / 2
|
||||||
wall_mask[:, src_nodes],
|
return weights.unsqueeze(-1)
|
||||||
wall_mask[:, dst_nodes]
|
|
||||||
).float() # [B, E]
|
|
||||||
|
|
||||||
return weights.unsqueeze(-1) # [B, E, 1]
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user