mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-15 05:11:10 +08:00
152 lines
5.6 KiB
Python
152 lines
5.6 KiB
Python
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
from torch_geometric.data import Data, Batch
|
||
from torch_geometric.utils import dense_to_sparse
|
||
|
||
def convert_map_to_graph(map):
    """Convert a 2-D tile map into a graph over its non-wall cells.

    Every cell whose value is not ``1`` (``1`` encodes a wall) becomes a node;
    nodes are numbered in row-major order. Two nodes are linked when they are
    horizontally or vertically adjacent. Each adjacent pair is stored once,
    as a directed edge pointing right or down.

    Args:
        map: Sequence of rows of integer tile types. The column count is taken
            from the first row. (The name shadows the builtin ``map``; it is
            kept so keyword callers keep working.)

    Returns:
        A ``Data`` object with ``x`` holding the tile type of every node
        (1-D long tensor) and ``edge_index`` of shape ``[2, E]`` (long).
    """
    rows = len(map)
    # An empty map yields an empty graph instead of an IndexError.
    cols = len(map[0]) if rows else 0

    node_indices = {}  # (row, col) -> node id, in row-major insertion order
    node_types = []    # tile type per node, indexed by node id
    for r in range(rows):
        for c in range(cols):
            tile = map[r][c]
            if tile != 1:  # skip walls
                node_indices[(r, c)] = len(node_types)
                node_types.append(tile)

    edge_list = []
    for (r, c), node in node_indices.items():
        # Only right and down neighbours are visited, so each adjacent pair
        # is emitted exactly once. Out-of-bounds cells are never in the dict.
        for neighbor in ((r, c + 1), (r + 1, c)):
            target = node_indices.get(neighbor)
            if target is not None:
                edge_list.append((node, target))

    # reshape(-1, 2) keeps the result at shape [2, 0] when there are no edges;
    # transposing an empty 1-D tensor would give shape [0] instead.
    edge_index = (
        torch.tensor(edge_list, dtype=torch.long).reshape(-1, 2).t().contiguous()
    )
    node_features = torch.tensor(node_types, dtype=torch.long)

    return Data(x=node_features, edge_index=edge_index)
|
||
|
||
def soft_convert_map_to_graph(map_tensor, tau=0.5, threshold=0.1):
    """Convert a batch of per-cell class scores into graph data.

    Each cell is discretised with a straight-through Gumbel-Softmax
    (``hard=True``), cells of class ``1`` are treated as walls, and an edge is
    kept between right/down neighbours when neither endpoint is a wall.

    Args:
        map_tensor: Tensor of shape ``[B, H, W, C]`` passed to
            ``F.gumbel_softmax`` as logits.
        tau: Gumbel-Softmax temperature.
        threshold: An adjacency entry becomes an edge only when its weight is
            strictly greater than this value.

    Returns:
        A ``Data`` object with ``x`` of shape ``[B, N, 1]`` (sampled class id
        per cell, ``N = H * W``), ``edge_index`` of shape ``[2, E]`` and
        ``edge_attr`` of shape ``[E]``, where the edges of all samples are
        concatenated.

    NOTE(review): node indices in ``edge_index`` are per-sample (``0..N-1``)
    and are not offset by the sample position when concatenated — confirm
    that consumers expect this layout.
    """
    B, H, W, C = map_tensor.shape
    N = H * W

    # reshape (not view) so non-contiguous inputs, e.g. a permuted
    # [B, C, H, W] tensor, are accepted as well.
    y = F.gumbel_softmax(map_tensor.reshape(B, N, C), tau=tau, hard=True)  # [B, N, C]

    # Integer class index per cell (for embedding lookups).
    node_features = y.argmax(dim=-1).long().unsqueeze(-1)  # [B, N, 1]

    # With hard=True the forward values are exactly 0/1, so this is 1 for
    # non-wall cells and 0 for walls; gradients flow via the soft sample.
    passable = 1 - y[:, :, 1]  # [B, N]

    adjacency_matrix = torch.zeros(B, N, N, device=map_tensor.device)

    for r in range(H):
        for c in range(W):
            idx = r * W + c
            if c + 1 < W:
                right_idx = idx + 1
                adjacency_matrix[:, idx, right_idx] = passable[:, idx] * passable[:, right_idx]
            if r + 1 < H:
                down_idx = idx + W
                adjacency_matrix[:, idx, down_idx] = passable[:, idx] * passable[:, down_idx]

    edge_index_list, edge_weight_list = [], []
    for b in range(B):
        adjacency = adjacency_matrix[b]
        # nonzero() on the boolean mask lists the kept (src, dst) pairs.
        edge_index = torch.nonzero(adjacency > threshold, as_tuple=False).T  # [2, E]
        edge_weight = adjacency[edge_index[0], edge_index[1]]

        edge_index_list.append(edge_index)
        edge_weight_list.append(edge_weight)

    edge_index = torch.cat(edge_index_list, dim=1)
    edge_weight = torch.cat(edge_weight_list, dim=0)

    return Data(x=node_features, edge_index=edge_index, edge_attr=edge_weight)
|
||
|
||
class DynamicGraphConverter(nn.Module):
    """Turn a batch of per-cell class scores into a batch of grid graphs.

    The candidate edge set (right and down neighbours on a
    ``map_size x map_size`` grid) is computed once at construction. On each
    forward pass, edges touching a cell sampled as class ``1`` (assumed to be
    the wall class) are dropped.
    """

    def __init__(self, map_size=13):
        """Precompute the candidate edges for a square grid.

        Args:
            map_size: Side length of the square map.
        """
        super().__init__()
        self.map_size = map_size
        self.n_nodes = map_size * map_size

        # Candidate edges between right/down grid neighbours. Registered as a
        # non-persistent buffer so ``module.to(device)`` moves it with the
        # module while ``state_dict()`` stays unchanged.
        self.register_buffer(
            "base_edge_index", self._precompute_base_edges(), persistent=False
        )

    def _precompute_base_edges(self):
        """Return the ``[2, E]`` long tensor of right/down neighbour edges."""
        size = self.map_size
        directions = (
            (0, 1),  # right
            (1, 0),  # down
        )

        edge_list = []
        for r in range(size):
            for c in range(size):
                node = r * size + c
                for dr, dc in directions:
                    nr, nc = r + dr, c + dc
                    if 0 <= nr < size and 0 <= nc < size:
                        edge_list.append([node, nr * size + nc])

        # An explicit dtype plus reshape(-1, 2) keeps an edge-free grid
        # (map_size <= 1) at shape [2, 0] with integer dtype.
        edge_index = (
            torch.tensor(edge_list, dtype=torch.long).reshape(-1, 2).t().contiguous()
        )
        if edge_index.size(1) == 0:
            return edge_index
        return edge_index.unique(dim=1)

    def forward(self, map_probs, tau=0.5):
        """Sample a discrete map and build one graph per sample.

        Args:
            map_probs: Tensor of shape ``[B, C, H, W]``; it is passed to
                ``F.gumbel_softmax`` as logits. ``H * W`` is presumably
                expected to equal ``map_size ** 2`` — this is not checked.
            tau: Gumbel-Softmax temperature.

        Returns:
            A ``Batch`` built from one ``Data`` per sample, with ``x`` the
            sampled class id of every cell (shape ``[N]``), ``edge_index``
            the surviving candidate edges and ``edge_attr`` their weights
            (shape ``[E, 1]``).
        """
        B, C, H, W = map_probs.shape
        device = map_probs.device

        # Follow the input's device even if the module itself was not moved.
        self.base_edge_index = self.base_edge_index.to(device)

        # 1. Discretise node classes. reshape (not view) accepts
        #    non-contiguous inputs. The argmax below yields integer ids, so
        #    no gradient flows through ``node_ids``.
        node_logits = map_probs.reshape(B, C, -1).permute(0, 2, 1)  # [B, N, C]
        hard_nodes = F.gumbel_softmax(node_logits, tau=tau, hard=True)
        node_ids = hard_nodes.argmax(dim=-1)  # [B, N]

        # 2. Edge weights from wall membership (class 1 is assumed to be wall).
        wall_mask = (node_ids == 1).float()
        edge_weights = self._compute_dynamic_weights(wall_mask)

        # 3. Build one graph per sample, dropping edges that touch a wall.
        batch_data = []
        for b in range(B):
            valid_mask = (edge_weights[b] > 0.1).squeeze(-1)
            batch_data.append(
                Data(
                    x=node_ids[b],
                    edge_index=self.base_edge_index[:, valid_mask],
                    edge_attr=edge_weights[b][valid_mask],
                )
            )

        return Batch.from_data_list(batch_data)

    def _compute_dynamic_weights(self, wall_mask):
        """Compute per-edge weights from wall membership.

        Args:
            wall_mask: Tensor of shape ``[B, N]``, non-zero where a cell is a
                wall.

        Returns:
            Float tensor of shape ``[B, E, 1]``: ``0`` when either endpoint of
            the edge is a wall, ``1`` otherwise.
        """
        src_nodes = self.base_edge_index[0]  # [E]
        dst_nodes = self.base_edge_index[1]  # [E]

        # weight = 1 - (source is wall OR target is wall)
        weights = 1 - torch.logical_or(
            wall_mask[:, src_nodes],
            wall_mask[:, dst_nodes],
        ).float()  # [B, E]

        return weights.unsqueeze(-1)  # [B, E, 1]
|