From 50bb509a8453a98a4de9f62b24a4d82c6e0d3f68 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Sun, 16 Mar 2025 23:51:31 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E5=B0=86=E5=9C=B0=E5=9B=BE?= =?UTF-8?q?=E8=BD=AC=E6=8D=A2=E7=A7=BB=E8=87=B3=20shared?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- minamo/dataset.py | 29 +---------------------------- shared/graph.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 28 deletions(-) create mode 100644 shared/graph.py diff --git a/minamo/dataset.py b/minamo/dataset.py index 0ff667e..afd253c 100644 --- a/minamo/dataset.py +++ b/minamo/dataset.py @@ -1,7 +1,7 @@ import json import torch from torch.utils.data import Dataset -from torch_geometric.data import Data +from shared.graph import convert_map_to_graph def load_data(path: str): with open(path, 'r', encoding="utf-8") as f: @@ -13,33 +13,6 @@ def load_data(path: str): return data_list -def convert_map_to_graph(map): - rows = len(map) - cols = len(map[0]) - node_indices = {} - valid_nodes = [] - node_counter = 0 - - for r in range(rows): - for c in range(cols): - if map[r][c] != 1: # 排除墙体 - node_indices[(r, c)] = node_counter - valid_nodes.append((r, c, map[r][c])) # (行, 列, 地形类型) - node_counter += 1 - - edge_list = [] - for (r, c, _) in valid_nodes: - node = node_indices[(r, c)] - if c + 1 < cols and (r, c + 1) in node_indices: - edge_list.append((node, node_indices[(r, c + 1)])) - if r + 1 < rows and (r + 1, c) in node_indices: - edge_list.append((node, node_indices[(r + 1, c)])) - - edge_index = torch.tensor(edge_list, dtype=torch.long).T - node_features = torch.tensor([node_type for (_, _, node_type) in valid_nodes], dtype=torch.long) - - return Data(x=node_features, edge_index=edge_index) - class MinamoDataset(Dataset): def __init__(self, data_path: str): self.data = load_data(data_path) # 自定义数据加载函数 diff --git a/shared/graph.py b/shared/graph.py new file mode 100644 index 0000000..187051c --- /dev/null +++ b/shared/graph.py @@ -0,0 +1,29 @@ +import torch +from torch_geometric import Data + +def convert_map_to_graph(map): + rows = len(map) + cols = len(map[0]) + node_indices = {} + valid_nodes = [] + node_counter = 0 + + for r in range(rows): + for c in range(cols): + if map[r][c] != 1: # 排除墙体 + node_indices[(r, c)] = node_counter + valid_nodes.append((r, c, map[r][c])) # (行, 列, 地形类型) + node_counter += 1 + + edge_list = [] + for (r, c, _) in valid_nodes: + node = node_indices[(r, c)] + if c + 1 < cols and (r, c + 1) in node_indices: + edge_list.append((node, node_indices[(r, c + 1)])) + if r + 1 < rows and (r + 1, c) in node_indices: + edge_list.append((node, node_indices[(r + 1, c)])) + + edge_index = torch.tensor(edge_list, dtype=torch.long).T + node_features = torch.tensor([node_type for (_, _, node_type) in valid_nodes], dtype=torch.long) + + return Data(x=node_features, edge_index=edge_index) \ No newline at end of file