mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-23 03:51:11 +08:00
refactor: 将地图转换移至 shared
This commit is contained in:
parent
41a9e21247
commit
50bb509a84
@ -1,7 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from torch_geometric.data import Data
|
from shared.graph import convert_map_to_graph
|
||||||
|
|
||||||
def load_data(path: str):
|
def load_data(path: str):
|
||||||
with open(path, 'r', encoding="utf-8") as f:
|
with open(path, 'r', encoding="utf-8") as f:
|
||||||
@ -13,33 +13,6 @@ def load_data(path: str):
|
|||||||
|
|
||||||
return data_list
|
return data_list
|
||||||
|
|
||||||
def convert_map_to_graph(map):
|
|
||||||
rows = len(map)
|
|
||||||
cols = len(map[0])
|
|
||||||
node_indices = {}
|
|
||||||
valid_nodes = []
|
|
||||||
node_counter = 0
|
|
||||||
|
|
||||||
for r in range(rows):
|
|
||||||
for c in range(cols):
|
|
||||||
if map[r][c] != 1: # 排除墙体
|
|
||||||
node_indices[(r, c)] = node_counter
|
|
||||||
valid_nodes.append((r, c, map[r][c])) # (行, 列, 地形类型)
|
|
||||||
node_counter += 1
|
|
||||||
|
|
||||||
edge_list = []
|
|
||||||
for (r, c, _) in valid_nodes:
|
|
||||||
node = node_indices[(r, c)]
|
|
||||||
if c + 1 < cols and (r, c + 1) in node_indices:
|
|
||||||
edge_list.append((node, node_indices[(r, c + 1)]))
|
|
||||||
if r + 1 < rows and (r + 1, c) in node_indices:
|
|
||||||
edge_list.append((node, node_indices[(r + 1, c)]))
|
|
||||||
|
|
||||||
edge_index = torch.tensor(edge_list, dtype=torch.long).T
|
|
||||||
node_features = torch.tensor([node_type for (_, _, node_type) in valid_nodes], dtype=torch.long)
|
|
||||||
|
|
||||||
return Data(x=node_features, edge_index=edge_index)
|
|
||||||
|
|
||||||
class MinamoDataset(Dataset):
|
class MinamoDataset(Dataset):
|
||||||
def __init__(self, data_path: str):
|
def __init__(self, data_path: str):
|
||||||
self.data = load_data(data_path) # 自定义数据加载函数
|
self.data = load_data(data_path) # 自定义数据加载函数
|
||||||
|
|||||||
29
shared/graph.py
Normal file
29
shared/graph.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
import torch
|
||||||
|
from torch_geometric import Data
|
||||||
|
|
||||||
|
def convert_map_to_graph(map):
|
||||||
|
rows = len(map)
|
||||||
|
cols = len(map[0])
|
||||||
|
node_indices = {}
|
||||||
|
valid_nodes = []
|
||||||
|
node_counter = 0
|
||||||
|
|
||||||
|
for r in range(rows):
|
||||||
|
for c in range(cols):
|
||||||
|
if map[r][c] != 1: # 排除墙体
|
||||||
|
node_indices[(r, c)] = node_counter
|
||||||
|
valid_nodes.append((r, c, map[r][c])) # (行, 列, 地形类型)
|
||||||
|
node_counter += 1
|
||||||
|
|
||||||
|
edge_list = []
|
||||||
|
for (r, c, _) in valid_nodes:
|
||||||
|
node = node_indices[(r, c)]
|
||||||
|
if c + 1 < cols and (r, c + 1) in node_indices:
|
||||||
|
edge_list.append((node, node_indices[(r, c + 1)]))
|
||||||
|
if r + 1 < rows and (r + 1, c) in node_indices:
|
||||||
|
edge_list.append((node, node_indices[(r + 1, c)]))
|
||||||
|
|
||||||
|
edge_index = torch.tensor(edge_list, dtype=torch.long).T
|
||||||
|
node_features = torch.tensor([node_type for (_, _, node_type) in valid_nodes], dtype=torch.long)
|
||||||
|
|
||||||
|
return Data(x=node_features, edge_index=edge_index)
|
||||||
Loading…
Reference in New Issue
Block a user