diff --git a/minamo/dataset.py b/minamo/dataset.py index 0ff667e..afd253c 100644 --- a/minamo/dataset.py +++ b/minamo/dataset.py @@ -1,7 +1,7 @@ import json import torch from torch.utils.data import Dataset -from torch_geometric.data import Data +from shared.graph import convert_map_to_graph def load_data(path: str): with open(path, 'r', encoding="utf-8") as f: @@ -13,33 +13,6 @@ def load_data(path: str): return data_list -def convert_map_to_graph(map): - rows = len(map) - cols = len(map[0]) - node_indices = {} - valid_nodes = [] - node_counter = 0 - - for r in range(rows): - for c in range(cols): - if map[r][c] != 1: # 排除墙体 - node_indices[(r, c)] = node_counter - valid_nodes.append((r, c, map[r][c])) # (行, 列, 地形类型) - node_counter += 1 - - edge_list = [] - for (r, c, _) in valid_nodes: - node = node_indices[(r, c)] - if c + 1 < cols and (r, c + 1) in node_indices: - edge_list.append((node, node_indices[(r, c + 1)])) - if r + 1 < rows and (r + 1, c) in node_indices: - edge_list.append((node, node_indices[(r + 1, c)])) - - edge_index = torch.tensor(edge_list, dtype=torch.long).T - node_features = torch.tensor([node_type for (_, _, node_type) in valid_nodes], dtype=torch.long) - - return Data(x=node_features, edge_index=edge_index) - class MinamoDataset(Dataset): def __init__(self, data_path: str): self.data = load_data(data_path) # 自定义数据加载函数 diff --git a/shared/graph.py b/shared/graph.py new file mode 100644 index 0000000..187051c --- /dev/null +++ b/shared/graph.py @@ -0,0 +1,29 @@ +import torch +from torch_geometric import Data + +def convert_map_to_graph(map): + rows = len(map) + cols = len(map[0]) + node_indices = {} + valid_nodes = [] + node_counter = 0 + + for r in range(rows): + for c in range(cols): + if map[r][c] != 1: # 排除墙体 + node_indices[(r, c)] = node_counter + valid_nodes.append((r, c, map[r][c])) # (行, 列, 地形类型) + node_counter += 1 + + edge_list = [] + for (r, c, _) in valid_nodes: + node = node_indices[(r, c)] + if c + 1 < cols and (r, c + 1) in node_indices: + edge_list.append((node, node_indices[(r, c + 1)])) + if r + 1 < rows and (r + 1, c) in node_indices: + edge_list.append((node, node_indices[(r + 1, c)])) + + edge_index = torch.tensor(edge_list, dtype=torch.long).T + node_features = torch.tensor([node_type for (_, _, node_type) in valid_nodes], dtype=torch.long) + + return Data(x=node_features, edge_index=edge_index) \ No newline at end of file