mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 21:01:12 +08:00
48 lines
1.5 KiB
Python
48 lines
1.5 KiB
Python
import torch
|
|
from torch_geometric.data import Data, Batch
|
|
|
|
def convert_soft_map_to_graph(map_probs: torch.Tensor):
    """Build a soft grid graph directly from per-cell class probabilities.

    Every cell of the ``H x W`` map becomes one node, numbered row-major as
    ``r * W + c``. Each node is linked to its right and its lower neighbour
    (when they exist), so every grid adjacency is stored once as a single
    directed edge.

    Args:
        map_probs: Tensor of shape ``[C, H, W]``. The original comments
            describe it as softmax output with ``C = 32``; nothing in this
            function depends on that particular channel count.

    Returns:
        A ``Data`` object with

        * ``x``: ``[H * W, C]`` node features (the channel vector of each
          cell),
        * ``edge_index``: ``[2, E]`` int64 tensor on ``map_probs.device``,
          ordered node by node with the right edge before the down edge,
        * ``edge_attr``: ``[C, E]`` tensor holding the mean of the two
          endpoint feature vectors of every edge.

        For maps without any neighbouring cells (e.g. ``1 x 1``) the edge
        tensors are empty (``E = 0``) instead of raising.
    """
    C, H, W = map_probs.shape
    device = map_probs.device

    # reshape() instead of view(): also accepts non-contiguous inputs and is
    # still a zero-copy view whenever the input is contiguous.
    flat_probs = map_probs.reshape(C, H * W)  # [C, N]
    node_features = flat_probs.T  # [N, C]

    # Row-major node ids laid out on the grid.
    rows = torch.arange(H, device=device).view(H, 1)
    cols = torch.arange(W, device=device).view(1, W)
    node_ids = rows * W + cols  # [H, W]

    # Last axis enumerates the two candidate edges of a node: (right, down).
    sources = node_ids.unsqueeze(-1).expand(H, W, 2)
    targets = torch.stack((node_ids + 1, node_ids + W), dim=-1)
    in_bounds = torch.stack(
        ((cols + 1 < W).expand(H, W), (rows + 1 < H).expand(H, W)),
        dim=-1,
    )

    # Boolean masking walks the [H, W, 2] grid in row-major order, which
    # yields the same edge order as a nested "for r / for c" loop that emits
    # the right edge and then the down edge. The result is always an int64
    # [2, E] tensor, including E == 0.
    edge_index = torch.stack((sources[in_bounds], targets[in_bounds]))

    # Soft edge weight: average of the probability vectors of both endpoints.
    # Gathering from the flattened map equals indexing [:, id // W, id % W].
    # NOTE(review): PyG documents edge_attr as [num_edges, num_features];
    # the [C, E] layout is kept here for backward compatibility - confirm
    # that consumers (and batching along dim 0) really expect it.
    soft_edge_weight = (
        flat_probs[:, edge_index[0]] + flat_probs[:, edge_index[1]]
    ) / 2

    return Data(
        x=node_features,
        edge_index=edge_index,
        edge_attr=soft_edge_weight,
    )
|
|
|
|
def batch_convert_soft_map_to_graph(batch_map_probs):
    """Convert a batch of maps ``[B, C, H, W]`` into one batched graph.

    Each sample along the first axis is turned into its own graph with
    ``convert_soft_map_to_graph``; the resulting graphs are then merged
    with ``Batch.from_data_list``.

    Args:
        batch_map_probs: 4-D tensor whose first axis is the batch axis.

    Returns:
        The ``Batch`` built from the per-sample graphs, in batch order.
    """
    # Unpacking all four dimensions keeps the rank check: anything that is
    # not 4-D fails right here.
    num_samples, _, _, _ = batch_map_probs.shape

    # One graph per sample.
    per_sample_graphs = [
        convert_soft_map_to_graph(batch_map_probs[idx])
        for idx in range(num_samples)
    ]

    # Merge all graphs into a single Batch.
    return Batch.from_data_list(per_sample_graphs)
|