mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-17 06:51:11 +08:00
60 lines
1.8 KiB
Python
60 lines
1.8 KiB
Python
import json
|
|
import torch
|
|
from torch.utils.data import Dataset
|
|
from torch_geometric.data import Data
|
|
|
|
def load_data(path: str) -> list:
    """Load the sample records stored in a JSON dataset file.

    The file must contain a top-level object with a ``"data"`` key whose value
    is itself an object. The values of that inner object are the samples; its
    keys are discarded.

    Args:
        path: Path to a UTF-8 encoded JSON file.

    Returns:
        A list of the values of ``data["data"]``, in the order they appear in
        the parsed object.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If the top-level object has no ``"data"`` key.
    """
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    # Only the records are needed downstream; their keys are not used.
    return list(data["data"].values())
|
|
|
|
def convert_map_to_graph(map, *, wall_tile=1):
    """Build a grid graph over the non-wall tiles of a 2D map.

    Every tile whose value differs from ``wall_tile`` becomes a node, numbered
    in row-major order. Two nodes are connected when their tiles are
    horizontally or vertically adjacent.

    Args:
        map: 2D grid given as a sequence of rows, each a sequence of tile
            types. An empty map yields an empty graph.
        wall_tile: Tile value treated as a wall and excluded from the graph.
            Defaults to ``1``, the value previously hard-coded.

    Returns:
        A ``Data`` object with:
          * ``x``: long tensor of shape ``(num_nodes,)`` holding each node's
            tile type.
          * ``edge_index``: long tensor of shape ``(2, num_edges)``; this shape
            holds even when there are no edges.

    NOTE(review): each adjacent pair is stored once, as an edge from the
    left/upper tile to the right/lower one. PyG message passing follows edge
    direction, so if an undirected graph is intended, the reverse edges may
    be needed (e.g. ``torch_geometric.utils.to_undirected``). Confirm this
    against the model.
    """
    # (row, col) -> node id, assigned in row-major order.
    node_indices = {}
    node_types = []
    for r, row in enumerate(map):
        for c, tile in enumerate(row):
            if tile != wall_tile:  # exclude walls
                node_indices[(r, c)] = len(node_types)
                node_types.append(tile)

    # Check only the right and lower neighbours so each adjacent pair is
    # emitted once. Dict insertion order keeps the edges in row-major order.
    edge_list = []
    for (r, c), node in node_indices.items():
        for neighbor in ((r, c + 1), (r + 1, c)):
            other = node_indices.get(neighbor)
            if other is not None:
                edge_list.append((node, other))

    # view(-1, 2) keeps the (2, 0) shape when edge_list is empty; a bare .T on
    # the empty tensor would leave it 1-D with shape (0,).
    edge_index = torch.tensor(edge_list, dtype=torch.long).view(-1, 2).t().contiguous()
    node_features = torch.tensor(node_types, dtype=torch.long)

    return Data(x=node_features, edge_index=edge_index)
|
|
|
|
class MinamoDataset(Dataset):
    """Dataset of map pairs, their similarity scores and their grid graphs.

    Records come from ``load_data``. Each record is read for the keys
    ``map1``, ``map2``, ``visionSimilarity`` and ``topoSimilarity``.
    """

    def __init__(self, data_path: str):
        # Parsed JSON records are kept in memory. Tensors and graphs are built
        # on demand in __getitem__.
        self.data = load_data(data_path)

    def __len__(self) -> int:
        """Return the number of loaded records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a 6-tuple.

        The tuple holds, in order:
          * ``map1`` and ``map2`` as long tensors;
          * ``visionSimilarity`` and ``topoSimilarity``, each as a float32
            tensor of shape ``(1,)``;
          * the graphs of ``map1`` and ``map2``, as built by
            ``convert_map_to_graph``.
        """
        item = self.data[idx]

        # torch.tensor with an explicit dtype replaces the legacy typed
        # constructors (torch.LongTensor / torch.FloatTensor) and produces
        # the same dtypes.
        return (
            torch.tensor(item['map1'], dtype=torch.long),
            torch.tensor(item['map2'], dtype=torch.long),
            torch.tensor([item['visionSimilarity']], dtype=torch.float32),
            torch.tensor([item['topoSimilarity']], dtype=torch.float32),
            convert_map_to_graph(item['map1']),
            convert_map_to_graph(item['map2']),
        )
|