mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-17 06:51:11 +08:00
refactor: Minamo Model 改为 softmax 输入
This commit is contained in:
parent
452df38866
commit
cb9e67dff7
@ -1,7 +1,8 @@
|
||||
import json
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import Dataset
|
||||
from shared.graph import convert_map_to_graph
|
||||
from shared.graph import convert_soft_map_to_graph
|
||||
|
||||
def load_data(path: str):
|
||||
with open(path, 'r', encoding="utf-8") as f:
|
||||
@ -22,11 +23,18 @@ class MinamoDataset(Dataset):
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = self.data[idx]
|
||||
|
||||
map1_probs = F.one_hot(torch.LongTensor(item['map1']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
|
||||
map2_probs = F.one_hot(torch.LongTensor(item['map2']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
|
||||
|
||||
graph1 = convert_soft_map_to_graph(map1_probs)
|
||||
graph2 = convert_soft_map_to_graph(map2_probs)
|
||||
|
||||
return (
|
||||
torch.LongTensor(item['map1']),
|
||||
torch.LongTensor(item['map2']),
|
||||
map1_probs,
|
||||
map2_probs,
|
||||
torch.FloatTensor([item['visionSimilarity']]),
|
||||
torch.FloatTensor([item['topoSimilarity']]),
|
||||
convert_map_to_graph(item['map1']),
|
||||
convert_map_to_graph(item['map2'])
|
||||
graph1,
|
||||
graph2
|
||||
)
|
||||
|
||||
@ -9,8 +9,8 @@ class MinamoTopoModel(nn.Module):
|
||||
self, tile_types=32, emb_dim=64, hidden_dim=64, out_dim=512, mlp_dim=128
|
||||
):
|
||||
super().__init__()
|
||||
# 嵌入层
|
||||
self.embedding = torch.nn.Embedding(tile_types, emb_dim)
|
||||
# 传入 softmax 概率值,直接映射
|
||||
self.input_proj = torch.nn.Linear(tile_types, emb_dim)
|
||||
# 图卷积层
|
||||
self.conv1 = GATConv(emb_dim, hidden_dim*2, heads=8, dropout=0.2)
|
||||
self.conv2 = GATConv(hidden_dim*16, hidden_dim*4, heads=4)
|
||||
@ -31,7 +31,7 @@ class MinamoTopoModel(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, graph: Data):
|
||||
x = self.embedding(graph.x)
|
||||
x = self.input_proj(graph.x)
|
||||
# identity = x
|
||||
|
||||
x = self.conv1(x, graph.edge_index)
|
||||
|
||||
@ -4,30 +4,30 @@ import torch.nn.functional as F
|
||||
from shared.attention import CBAM
|
||||
|
||||
class MinamoVisionModel(nn.Module):
|
||||
def __init__(self, tile_types=32, embedding_dim=32, conv_channels=64, out_dim=128):
|
||||
def __init__(self, tile_types=32, conv_ch=32, out_dim=128):
|
||||
super().__init__()
|
||||
# 嵌入层处理不同图块类型
|
||||
self.embedding = nn.Embedding(tile_types, embedding_dim)
|
||||
# 输入 softmax 概率值
|
||||
self.input_conv = nn.Conv2d(tile_types, conv_ch, 3, padding=1)
|
||||
|
||||
# 卷积部分
|
||||
self.vision_conv = nn.Sequential(
|
||||
nn.Conv2d(embedding_dim, conv_channels, 3, padding=1),
|
||||
nn.BatchNorm2d(conv_channels),
|
||||
CBAM(conv_channels),
|
||||
nn.Conv2d(conv_ch, conv_ch*2, 3, padding=1),
|
||||
nn.BatchNorm2d(conv_ch*2),
|
||||
CBAM(conv_ch*2),
|
||||
nn.GELU(),
|
||||
nn.MaxPool2d(2),
|
||||
nn.Dropout2d(0.4),
|
||||
|
||||
nn.Conv2d(conv_channels, conv_channels*2, 3, padding=1),
|
||||
nn.BatchNorm2d(conv_channels*2),
|
||||
CBAM(conv_channels*2),
|
||||
nn.Conv2d(conv_ch*2, conv_ch*4, 3, padding=1),
|
||||
nn.BatchNorm2d(conv_ch*4),
|
||||
CBAM(conv_ch*4),
|
||||
nn.GELU(),
|
||||
nn.MaxPool2d(2),
|
||||
nn.Dropout2d(0.4),
|
||||
|
||||
nn.Conv2d(conv_channels*2, conv_channels*4, 3, padding=1),
|
||||
nn.BatchNorm2d(conv_channels*4),
|
||||
CBAM(conv_channels*4),
|
||||
nn.Conv2d(conv_ch*4, conv_ch*8, 3, padding=1),
|
||||
nn.BatchNorm2d(conv_ch*8),
|
||||
CBAM(conv_ch*8),
|
||||
nn.GELU(),
|
||||
|
||||
nn.AdaptiveMaxPool2d(1)
|
||||
@ -36,13 +36,11 @@ class MinamoVisionModel(nn.Module):
|
||||
# 输出为向量
|
||||
self.vision_head = nn.Sequential(
|
||||
nn.Dropout(0.4),
|
||||
nn.Linear(conv_channels*4, out_dim)
|
||||
nn.Linear(conv_ch*8, out_dim)
|
||||
)
|
||||
|
||||
def forward(self, map):
|
||||
x = self.embedding(map)
|
||||
x = x.permute(0, 3, 1, 2)
|
||||
|
||||
x = self.input_conv(map)
|
||||
x = self.vision_conv(x)
|
||||
x = x.view(x.size(0), -1) # 展平
|
||||
|
||||
|
||||
126
shared/graph.py
126
shared/graph.py
@ -2,7 +2,37 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch_geometric.data import Data, Batch
|
||||
from torch_geometric.utils import dense_to_sparse
|
||||
|
||||
def convert_soft_map_to_graph(map_probs):
|
||||
"""
|
||||
直接使用 Softmax 概率构建 soft 图结构
|
||||
"""
|
||||
C, H, W = map_probs.shape # [32, H, W]
|
||||
N = H * W
|
||||
device = map_probs.device
|
||||
|
||||
# 计算 soft 节点特征
|
||||
node_features = map_probs.view(C, N).T # [N, C]
|
||||
|
||||
# 计算 soft 邻接边(基于 soft 权重)
|
||||
edge_list = []
|
||||
for r in range(H):
|
||||
for c in range(W):
|
||||
node = r * W + c
|
||||
if c + 1 < W:
|
||||
right = node + 1
|
||||
edge_list.append([node, right])
|
||||
if r + 1 < H:
|
||||
down = node + W
|
||||
edge_list.append([node, down])
|
||||
|
||||
edge_index = torch.tensor(edge_list).t().to(device)
|
||||
|
||||
# 计算 soft 边权重(基于 Softmax 概率)
|
||||
soft_edge_weight = (map_probs[:, edge_index[0] // W, edge_index[0] % W] +
|
||||
map_probs[:, edge_index[1] // W, edge_index[1] % W]) / 2
|
||||
|
||||
return Data(x=node_features, edge_index=edge_index, edge_attr=soft_edge_weight)
|
||||
|
||||
def convert_map_to_graph(map):
|
||||
rows = len(map)
|
||||
@ -31,68 +61,16 @@ def convert_map_to_graph(map):
|
||||
|
||||
return Data(x=node_features, edge_index=edge_index)
|
||||
|
||||
def soft_convert_map_to_graph(map_tensor, tau=0.5, threshold=0.1):
|
||||
"""
|
||||
将地图批量转换为 GNN 可用的 Graph Data,使用 soft 策略
|
||||
"""
|
||||
B, H, W, C = map_tensor.shape
|
||||
N = H * W
|
||||
|
||||
# 使用 Gumbel-Softmax 确保是 one-hot 形式
|
||||
y = F.gumbel_softmax(map_tensor.view(B, N, C), tau=tau, hard=True) # [B, N, C]
|
||||
|
||||
# 计算整数索引(用于 embedding)
|
||||
node_features = y.argmax(dim=-1).long().unsqueeze(-1) # [B, N, 1]
|
||||
|
||||
# 取出墙体的 soft mask
|
||||
wall_mask = y[:, :, 1] # [B, N]
|
||||
|
||||
adjacency_matrix = torch.zeros(B, N, N, device=map_tensor.device)
|
||||
|
||||
def get_index(r, c):
|
||||
return r * W + c
|
||||
|
||||
for r in range(H):
|
||||
for c in range(W):
|
||||
idx = get_index(r, c)
|
||||
if c + 1 < W:
|
||||
right_idx = get_index(r, c + 1)
|
||||
adjacency_matrix[:, idx, right_idx] = (1 - wall_mask[:, idx]) * (1 - wall_mask[:, right_idx])
|
||||
if r + 1 < H:
|
||||
down_idx = get_index(r + 1, c)
|
||||
adjacency_matrix[:, idx, down_idx] = (1 - wall_mask[:, idx]) * (1 - wall_mask[:, down_idx])
|
||||
|
||||
edge_index_list, edge_weight_list = [], []
|
||||
for b in range(B):
|
||||
adj_bin = (adjacency_matrix[b] > threshold).float()
|
||||
edge_index = torch.nonzero(adj_bin, as_tuple=False).T # [2, E]
|
||||
edge_weight = adjacency_matrix[b][edge_index[0], edge_index[1]]
|
||||
|
||||
edge_index_list.append(edge_index)
|
||||
edge_weight_list.append(edge_weight)
|
||||
|
||||
edge_index = torch.cat(edge_index_list, dim=1)
|
||||
edge_weight = torch.cat(edge_weight_list, dim=0)
|
||||
|
||||
return Data(x=node_features, edge_index=edge_index, edge_attr=edge_weight)
|
||||
|
||||
class DynamicGraphConverter(nn.Module):
|
||||
def __init__(self, map_size=13):
|
||||
super().__init__()
|
||||
self.map_size = map_size
|
||||
self.n_nodes = map_size * map_size
|
||||
|
||||
# 预计算所有可能的边索引组合(包括对角线)
|
||||
self.base_edge_index = self._precompute_base_edges()
|
||||
|
||||
|
||||
def _precompute_base_edges(self):
|
||||
"""预生成全连接边索引(包含所有可能邻接)"""
|
||||
edge_list = []
|
||||
directions = [
|
||||
(0, 1), # 右
|
||||
(1, 0), # 下
|
||||
]
|
||||
|
||||
directions = [(0, 1), (1, 0)]
|
||||
for r in range(self.map_size):
|
||||
for c in range(self.map_size):
|
||||
node = r * self.map_size + c
|
||||
@ -101,51 +79,41 @@ class DynamicGraphConverter(nn.Module):
|
||||
if 0 <= nr < self.map_size and 0 <= nc < self.map_size:
|
||||
neighbor = nr * self.map_size + nc
|
||||
edge_list.append([node, neighbor])
|
||||
|
||||
return torch.tensor(edge_list).t().contiguous().unique(dim=1)
|
||||
|
||||
def forward(self, map_probs, tau=0.5):
|
||||
B, C, H, W = map_probs.shape
|
||||
device = map_probs.device
|
||||
|
||||
self.base_edge_index = self.base_edge_index.to(device)
|
||||
|
||||
# 1. 节点特征离散化(保持可导)
|
||||
|
||||
# 1. 计算可微的节点 ID
|
||||
node_logits = map_probs.view(B, C, -1).permute(0, 2, 1) # [B, N, C]
|
||||
hard_nodes = F.gumbel_softmax(node_logits, tau=tau, hard=True)
|
||||
node_ids = hard_nodes.argmax(dim=-1) # [B, N]
|
||||
node_ids = (hard_nodes * torch.arange(C, device=device).view(1, 1, -1)).sum(dim=-1).long()
|
||||
|
||||
# 2. 动态边权重计算
|
||||
wall_mask = (node_ids == 1).float() # 假设类别1是墙体
|
||||
# 2. 计算 soft 壁障 mask
|
||||
wall_mask = torch.sigmoid((node_ids - 1) * 10) # 类别 1 代表墙体,soft 处理
|
||||
edge_weights = self._compute_dynamic_weights(wall_mask)
|
||||
|
||||
# 3. 构建动态图
|
||||
batch_data = []
|
||||
for b in range(B):
|
||||
# 动态过滤无效边(与墙体相连的边)
|
||||
valid_mask = (edge_weights[b] > 0.1).squeeze(-1)
|
||||
dynamic_edge_index = self.base_edge_index[:, valid_mask]
|
||||
dynamic_edge_attr = edge_weights[b][valid_mask]
|
||||
soft_mask = torch.sigmoid((edge_weights[b] - 0.1) * 10) # 软门控
|
||||
dynamic_edge_attr = edge_weights[b] * soft_mask # 仍然保留梯度
|
||||
|
||||
data = Data(
|
||||
x=node_ids[b],
|
||||
edge_index=dynamic_edge_index,
|
||||
edge_index=self.base_edge_index,
|
||||
edge_attr=dynamic_edge_attr
|
||||
)
|
||||
batch_data.append(data)
|
||||
|
||||
|
||||
return Batch.from_data_list(batch_data)
|
||||
|
||||
def _compute_dynamic_weights(self, wall_mask):
|
||||
"""基于墙体存在性计算动态边权重"""
|
||||
# wall_mask: [B, N]
|
||||
src_nodes = self.base_edge_index[0] # [E]
|
||||
dst_nodes = self.base_edge_index[1] # [E]
|
||||
src_nodes = self.base_edge_index[0]
|
||||
dst_nodes = self.base_edge_index[1]
|
||||
|
||||
# 边权重 = 1 - (源是墙 OR 目标墙)
|
||||
weights = 1 - torch.logical_or(
|
||||
wall_mask[:, src_nodes],
|
||||
wall_mask[:, dst_nodes]
|
||||
).float() # [B, E]
|
||||
|
||||
return weights.unsqueeze(-1) # [B, E, 1]
|
||||
# 让梯度能正确回传
|
||||
weights = 1 - (wall_mask[:, src_nodes] + wall_mask[:, dst_nodes]) / 2
|
||||
return weights.unsqueeze(-1)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user