mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-18 15:41:11 +08:00
refactor: 将地图转换移至 shared
This commit is contained in:
parent
41a9e21247
commit
50bb509a84
@ -1,7 +1,7 @@
|
||||
import json
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from torch_geometric.data import Data
|
||||
from shared.graph import convert_map_to_graph
|
||||
|
||||
def load_data(path: str):
|
||||
with open(path, 'r', encoding="utf-8") as f:
|
||||
@ -13,33 +13,6 @@ def load_data(path: str):
|
||||
|
||||
return data_list
|
||||
|
||||
def convert_map_to_graph(map):
|
||||
rows = len(map)
|
||||
cols = len(map[0])
|
||||
node_indices = {}
|
||||
valid_nodes = []
|
||||
node_counter = 0
|
||||
|
||||
for r in range(rows):
|
||||
for c in range(cols):
|
||||
if map[r][c] != 1: # 排除墙体
|
||||
node_indices[(r, c)] = node_counter
|
||||
valid_nodes.append((r, c, map[r][c])) # (行, 列, 地形类型)
|
||||
node_counter += 1
|
||||
|
||||
edge_list = []
|
||||
for (r, c, _) in valid_nodes:
|
||||
node = node_indices[(r, c)]
|
||||
if c + 1 < cols and (r, c + 1) in node_indices:
|
||||
edge_list.append((node, node_indices[(r, c + 1)]))
|
||||
if r + 1 < rows and (r + 1, c) in node_indices:
|
||||
edge_list.append((node, node_indices[(r + 1, c)]))
|
||||
|
||||
edge_index = torch.tensor(edge_list, dtype=torch.long).T
|
||||
node_features = torch.tensor([node_type for (_, _, node_type) in valid_nodes], dtype=torch.long)
|
||||
|
||||
return Data(x=node_features, edge_index=edge_index)
|
||||
|
||||
class MinamoDataset(Dataset):
|
||||
def __init__(self, data_path: str):
|
||||
self.data = load_data(data_path) # 自定义数据加载函数
|
||||
|
||||
29
shared/graph.py
Normal file
29
shared/graph.py
Normal file
@ -0,0 +1,29 @@
|
||||
import torch
|
||||
from torch_geometric import Data
|
||||
|
||||
def convert_map_to_graph(map):
|
||||
rows = len(map)
|
||||
cols = len(map[0])
|
||||
node_indices = {}
|
||||
valid_nodes = []
|
||||
node_counter = 0
|
||||
|
||||
for r in range(rows):
|
||||
for c in range(cols):
|
||||
if map[r][c] != 1: # 排除墙体
|
||||
node_indices[(r, c)] = node_counter
|
||||
valid_nodes.append((r, c, map[r][c])) # (行, 列, 地形类型)
|
||||
node_counter += 1
|
||||
|
||||
edge_list = []
|
||||
for (r, c, _) in valid_nodes:
|
||||
node = node_indices[(r, c)]
|
||||
if c + 1 < cols and (r, c + 1) in node_indices:
|
||||
edge_list.append((node, node_indices[(r, c + 1)]))
|
||||
if r + 1 < rows and (r + 1, c) in node_indices:
|
||||
edge_list.append((node, node_indices[(r + 1, c)]))
|
||||
|
||||
edge_index = torch.tensor(edge_list, dtype=torch.long).T
|
||||
node_features = torch.tensor([node_type for (_, _, node_type) in valid_nodes], dtype=torch.long)
|
||||
|
||||
return Data(x=node_features, edge_index=edge_index)
|
||||
Loading…
Reference in New Issue
Block a user