ginka-generator/ginka/dataset.py

118 lines
4.0 KiB
Python

import json
import random
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from minamo.model.model import MinamoModel
from shared.graph import differentiable_convert_to_data
from shared.utils import random_smooth_onehot
def load_data(path: str):
    """Read a JSON file and return the values of its top-level ``"data"`` mapping.

    Args:
        path: Path to a UTF-8 encoded JSON file whose root object has a
            ``"data"`` key mapping to an object.

    Returns:
        A list of the values of ``data["data"]`` in file order; the keys are
        discarded.
    """
    with open(path, 'r', encoding="utf-8") as f:
        data = json.load(f)
    return list(data["data"].values())
def load_minamo_gan_data(data: list):
    """Convert raw reference entries into Minamo GAN training tuples.

    Args:
        data: Entries that each provide the keys ``'map1'``, ``'map2'``,
            ``'visionSimilarity'`` and ``'topoSimilarity'``.

    Returns:
        A list of 5-tuples ``(map1, map2, vision_sim, topo_sim, True)``. The
        trailing ``True`` is the ``review`` flag consumed by
        ``MinamoGANDataset.__getitem__``.
    """
    return [
        (one['map1'], one['map2'], one['visionSimilarity'], one['topoSimilarity'], True)
        for one in data
    ]
class GinkaDataset(Dataset):
    """Dataset pairing each one-hot target map with Minamo features computed on the fly.

    Each item is a dict with the one-hot target map (moved to ``device``) and
    the vision / topology features returned by ``minamo`` for that map.
    """

    def __init__(self, data_path: str, device, minamo: MinamoModel):
        # Entries returned by load_data; each must provide a 'map' of tile ids.
        self.data = load_data(data_path)
        # NOTE(review): not read by the code visible in this module — confirm it is used elsewhere.
        self.max_size = 32
        self.minamo = minamo
        self.device = device

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Build the target tensor for entry ``idx`` and compute its Minamo features."""
        entry = self.data[idx]
        tile_ids = torch.LongTensor(entry['map'])
        onehot = F.one_hot(tile_ids, num_classes=32).permute(2, 0, 1).float()  # [32, H, W]

        # Random label-smoothing parameters, drawn per sample.
        lo = random.uniform(0.75, 0.9)
        hi = random.uniform(0.9, 1)
        eps = random.uniform(0, 0.25)

        # The graph is built from the *smoothed* map, while the vision input below
        # is the hard one-hot map. NOTE(review): confirm this asymmetry is intended.
        smoothed = random_smooth_onehot(onehot, lo, hi, eps)
        graph = differentiable_convert_to_data(smoothed).to(self.device)

        onehot = onehot.to(self.device)
        # NOTE(review): this runs outside torch.no_grad(), so the returned features
        # may carry autograd history if minamo has parameters requiring grad.
        vision_feat, topo_feat = self.minamo(onehot.unsqueeze(0), graph)

        return {
            "target_vision_feat": vision_feat,
            "target_topo_feat": topo_feat,
            "target": onehot,
        }
class GinkaWGANDataset(Dataset):
    """Dataset yielding randomly label-smoothed one-hot maps for WGAN training."""

    def __init__(self, data_path: str, device):
        # Entries returned by load_data; each must provide a 'map' of tile ids.
        self.data = load_data(data_path)
        self.device = device

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the smoothed one-hot encoding of entry ``idx`` on ``self.device``.

        The three smoothing parameters are drawn fresh for every call.
        """
        tile_ids = torch.LongTensor(self.data[idx]['map'])
        onehot = F.one_hot(tile_ids, num_classes=32).permute(2, 0, 1).float()  # [32, H, W]
        smoothed = random_smooth_onehot(
            onehot,
            random.uniform(0.75, 0.9),  # lower bound for the main-class probability
            random.uniform(0.9, 1),     # upper bound for the main-class probability
            random.uniform(0, 0.25),    # epsilon
        )
        return smoothed.to(self.device)
class MinamoGANDataset(Dataset):
    """Dataset of map pairs with vision / topology similarity labels for Minamo GAN training.

    Every item is a 5-tuple ``(map1, map2, vis_sim, topo_sim, review)``.
    Reference pairs loaded from disk carry ``review=True``. Items passed to
    :meth:`set_data` must follow the same tuple layout.
    """

    def __init__(self, refer_data_path, init_size=1000):
        """Load the reference pairs and seed the active data with a random subset.

        Args:
            refer_data_path: Path to the JSON file read by ``load_data``.
            init_size: Number of reference pairs to sample initially. It is
                clamped to the number of available reference pairs, so a small
                reference file no longer makes ``random.sample`` raise
                ``ValueError``.
        """
        self.refer = load_minamo_gan_data(load_data(refer_data_path))
        self.data = list()
        self.data.extend(self._sample_refer(init_size))

    def _sample_refer(self, k):
        """Return up to ``k`` distinct reference pairs, clamped to ``[0, len(self.refer)]``."""
        return random.sample(self.refer, max(0, min(int(k), len(self.refer))))

    def set_data(self, data: list):
        """Replace the active data with ``data`` plus a quarter as many reference pairs.

        The number of reference pairs mixed in is ``len(data) // 4``, capped at
        the number of reference pairs available.
        """
        # Snapshot first, so passing self.data back in is not emptied by clear().
        incoming = list(data)
        self.data.clear()
        self.data.extend(incoming)
        self.data.extend(self._sample_refer(len(incoming) // 4))

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _to_onehot(tile_ids, num_classes=32):
        """One-hot encode a 2-D grid of tile ids into a channels-first float tensor [C, H, W]."""
        return F.one_hot(torch.LongTensor(tile_ids), num_classes=num_classes).permute(2, 0, 1).float()

    def __getitem__(self, idx):
        """Return ``(map1, map2, vis_sim, topo_sim, graph1, graph2)`` for item ``idx``.

        ``map2`` is assumed to be the reference map and is always stored as tile
        ids. When ``review`` is falsy, ``map1`` is taken to already be a
        probability distribution and is only converted to a float tensor.
        """
        map1, map2, vis_sim, topo_sim, review = self.data[idx]
        if review:
            map1 = self._to_onehot(map1)  # [32, H, W]
        else:
            # NOTE(review): torch.FloatTensor(...) rejects CUDA tensors; callers of
            # set_data presumably pass CPU data or nested lists — confirm.
            map1 = torch.FloatTensor(map1)
        map2 = self._to_onehot(map2)  # [32, H, W]

        # One set of smoothing parameters is shared by both maps of the pair.
        min_main = random.uniform(0.75, 0.9)
        max_main = random.uniform(0.9, 1)
        epsilon = random.uniform(0, 0.25)
        if review:
            map1 = random_smooth_onehot(map1, min_main, max_main, epsilon)
        map2 = random_smooth_onehot(map2, min_main, max_main, epsilon)

        graph1 = differentiable_convert_to_data(map1)
        graph2 = differentiable_convert_to_data(map2)
        return (
            map1,
            map2,
            torch.FloatTensor([vis_sim]),
            torch.FloatTensor([topo_sim]),
            graph1,
            graph2,
        )