ginka-generator/shared/graph.py

120 lines
4.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, Batch
def convert_soft_map_to_graph(map_probs):
    """Build a soft grid graph directly from per-cell Softmax probabilities.

    Every cell of the ``H x W`` map becomes a node, numbered in row-major
    order (``node = r * W + c``). Each node is linked to its right and lower
    neighbour, stored in one direction only.

    Args:
        map_probs: tensor unpacked as ``[C, H, W]``; presumably per-cell
            class probabilities (e.g. Softmax output) — TODO confirm with callers.

    Returns:
        ``Data`` with:
          - ``x``: ``[H*W, C]`` node features (a view of ``map_probs`` for
            contiguous input, so gradients flow back to it),
          - ``edge_index``: ``[2, E]`` long tensor on ``map_probs.device``
            (``[2, 0]`` for a 1x1 map),
          - ``edge_attr``: ``[C, E]`` mean of the two endpoints' features.
            NOTE(review): PyG conventionally expects ``edge_attr`` as
            ``[E, num_features]``; this layout is kept for compatibility with
            existing callers — confirm before batching these graphs.
    """
    C, H, W = map_probs.shape
    N = H * W
    device = map_probs.device

    # reshape (rather than view) also accepts non-contiguous inputs.
    node_features = map_probs.reshape(C, N).T  # [N, C]

    # Right and down neighbour pairs in row-major order.
    edge_list = []
    for r in range(H):
        for c in range(W):
            node = r * W + c
            if c + 1 < W:
                edge_list.append([node, node + 1])
            if r + 1 < H:
                edge_list.append([node, node + W])

    # reshape(-1, 2) keeps a proper [2, 0] long tensor when there are no
    # edges (1x1 map) instead of a 1-D empty float tensor.
    edge_index = (
        torch.tensor(edge_list, dtype=torch.long, device=device)
        .reshape(-1, 2)
        .t()
    )

    # Soft edge weight: average of the two endpoints' probability vectors.
    src, dst = edge_index
    soft_edge_weight = (map_probs[:, src // W, src % W] +
                        map_probs[:, dst // W, dst % W]) / 2
    return Data(x=node_features, edge_index=edge_index, edge_attr=soft_edge_weight)
def convert_map_to_graph(map):
    """Build a graph over the non-wall cells of a 2-D tile map.

    Cells whose value equals 1 are treated as walls and dropped. Every other
    cell becomes a node, numbered in row-major order of the kept cells.
    A kept cell is linked to its kept right neighbour and its kept lower
    neighbour (one direction per pair).

    Args:
        map: 2-D grid indexable as ``map[r][c]`` holding integer tile types
            (e.g. a list of lists). Row length is taken from the first row;
            rows are assumed to be equally long — TODO confirm.

    Returns:
        ``Data`` with:
          - ``x``: ``[num_nodes]`` long tensor of tile types,
          - ``edge_index``: ``[2, E]`` long tensor (``[2, 0]`` when no two
            kept cells are adjacent).
    """
    rows = len(map)
    cols = len(map[0]) if rows else 0  # tolerate an empty map

    # (row, col) -> node id; dict insertion order is row-major.
    node_indices = {}
    node_types = []
    for r in range(rows):
        for c in range(cols):
            if map[r][c] != 1:  # 1 marks a wall cell
                node_indices[(r, c)] = len(node_types)
                node_types.append(map[r][c])

    edge_list = []
    for (r, c), node in node_indices.items():
        # Membership in node_indices already implies "in bounds and not a wall".
        right = node_indices.get((r, c + 1))
        if right is not None:
            edge_list.append((node, right))
        down = node_indices.get((r + 1, c))
        if down is not None:
            edge_list.append((node, down))

    # reshape(-1, 2) keeps the [2, 0] shape when there are no edges.
    edge_index = torch.tensor(edge_list, dtype=torch.long).reshape(-1, 2).t()
    node_features = torch.tensor(node_types, dtype=torch.long)
    return Data(x=node_features, edge_index=edge_index)
class DynamicGraphConverter(nn.Module):
    """Convert a batch of per-cell class scores into a batch of grid graphs.

    All samples share one fixed lattice over a ``map_size x map_size`` grid
    (right/down neighbour pairs). Per sample, node types are drawn with
    straight-through Gumbel-softmax, and each edge gets a weight that falls
    toward 0 when its endpoints are walls.
    """

    # Class-level default so instances unpickled from older versions
    # (which never set it in __init__) still work.
    wall_class = 1

    def __init__(self, map_size=13, wall_class=1):
        """
        Args:
            map_size: side length of the square map; forward() requires
                inputs with H == W == map_size.
            wall_class: channel index treated as a wall (defaults to 1).
        """
        super().__init__()
        self.map_size = map_size
        self.n_nodes = map_size * map_size
        self.wall_class = wall_class
        # Non-persistent buffer: moved by `module.to(device)` but kept out
        # of state_dict, so existing checkpoints load unchanged.
        self.register_buffer(
            "base_edge_index", self._precompute_base_edges(), persistent=False
        )

    def _precompute_base_edges(self):
        """Return the ``[2, E]`` long tensor of right/down neighbour pairs.

        Node ``r * map_size + c`` is linked to ``(r, c + 1)`` and ``(r + 1, c)``
        when those exist. Columns are sorted lexicographically by
        ``unique(dim=1)``.
        """
        size = self.map_size
        grid = torch.arange(size * size).view(size, size)
        right = torch.stack([grid[:, :-1].reshape(-1), grid[:, 1:].reshape(-1)])
        down = torch.stack([grid[:-1, :].reshape(-1), grid[1:, :].reshape(-1)])
        edges = torch.cat([right, down], dim=1)
        if edges.shape[1] == 0:  # 1x1 (or empty) map: no neighbour pairs
            return edges
        return edges.unique(dim=1)

    def forward(self, map_probs, tau=0.5):
        """Sample node types and build gated edge weights for each map.

        Args:
            map_probs: ``[B, C, H, W]`` scores, passed to ``gumbel_softmax``
                as logits. NOTE(review): the name suggests probabilities,
                while gumbel_softmax expects unnormalised log-probabilities —
                confirm what callers pass.
            tau: Gumbel-softmax temperature.

        Returns:
            ``Batch`` of B graphs, each with ``x``: ``[N]`` long node types,
            ``edge_index``: the shared base edges, and ``edge_attr``:
            ``[E, 1]`` gated edge weights.

        Raises:
            ValueError: if ``H`` or ``W`` differs from ``map_size``, since the
                precomputed edges would otherwise point at the wrong cells.
        """
        B, C, H, W = map_probs.shape
        if H != self.map_size or W != self.map_size:
            raise ValueError(
                f"expected a {self.map_size}x{self.map_size} map, got {H}x{W}"
            )
        device = map_probs.device
        self.base_edge_index = self.base_edge_index.to(device)

        # 1. Straight-through Gumbel-softmax: forward values are one-hot,
        #    gradients flow through the underlying soft sample.
        node_logits = map_probs.reshape(B, C, H * W).permute(0, 2, 1)  # [B, N, C]
        hard_nodes = F.gumbel_softmax(node_logits, tau=tau, hard=True)
        # Integer node types for `x`. Integer ids carry no gradient in any
        # form; argmax also avoids truncation by .long() of float products.
        node_ids = hard_nodes.argmax(dim=-1)  # [B, N]

        # 2. Wall indicator read straight from the one-hot wall column: it is
        #    1 for wall cells and 0 otherwise, and stays differentiable.
        #    (A threshold such as sigmoid((id - 1) * 10) on the integer id
        #    would also fire for every class above the wall class.)
        if C > self.wall_class:
            wall_mask = hard_nodes[..., self.wall_class]  # [B, N]
        else:
            wall_mask = hard_nodes.new_zeros(B, H * W)
        edge_weights = self._compute_dynamic_weights(wall_mask)  # [B, E, 1]

        # 3. Soft gate pushes weights below ~0.1 toward 0 while keeping gradients.
        gate = torch.sigmoid((edge_weights - 0.1) * 10)
        edge_attr = edge_weights * gate
        data_list = [
            Data(
                x=node_ids[b],
                edge_index=self.base_edge_index,
                edge_attr=edge_attr[b],
            )
            for b in range(B)
        ]
        return Batch.from_data_list(data_list)

    def _compute_dynamic_weights(self, wall_mask):
        """Return ``1 - mean(wall_mask[src], wall_mask[dst])`` for each base edge.

        Args:
            wall_mask: ``[B, N]`` per-node wall score.

        Returns:
            ``[B, E, 1]`` tensor: 1 between two non-wall nodes, 0 between two
            walls; differentiable w.r.t. ``wall_mask``.
        """
        src_nodes, dst_nodes = self.base_edge_index
        weights = 1 - (wall_mask[:, src_nodes] + wall_mask[:, dst_nodes]) / 2
        return weights.unsqueeze(-1)