mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
118 lines
4.0 KiB
Python
118 lines
4.0 KiB
Python
import json
|
|
import random
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch.utils.data import Dataset
|
|
from minamo.model.model import MinamoModel
|
|
from shared.graph import differentiable_convert_to_data
|
|
from shared.utils import random_smooth_onehot
|
|
|
|
def load_data(path: str) -> list:
    """Load a JSON dataset file and return the values of its ``"data"`` object.

    :param path: path to a UTF-8 encoded JSON file whose top-level object has
        a ``"data"`` key mapping to an object.
    :return: list of the values stored under ``"data"``, in file order.
    :raises KeyError: if the top-level object has no ``"data"`` key.
    """
    with open(path, 'r', encoding="utf-8") as f:
        data = json.load(f)

    return list(data["data"].values())
|
|
|
|
def load_minamo_gan_data(data: list) -> list:
    """Convert raw pair records into the tuples stored by ``MinamoGANDataset``.

    :param data: records that each provide ``map1``, ``map2``,
        ``visionSimilarity`` and ``topoSimilarity``; other keys are ignored.
    :return: list of ``(map1, map2, vision_sim, topo_sim, True)`` tuples. The
        trailing ``True`` is the ``review`` flag; ``MinamoGANDataset`` uses it
        to one-hot encode ``map1`` instead of treating it as a distribution.
    :raises KeyError: if a record is missing one of the required keys.
    """
    return [
        (one['map1'], one['map2'], one['visionSimilarity'], one['topoSimilarity'], True)
        for one in data
    ]
|
|
|
|
class GinkaDataset(Dataset):
    """Dataset of one-hot target maps paired with their Minamo features.

    Samples are read with ``load_data`` from the JSON file at ``data_path``.
    For every sample, ``minamo`` is run on the target map and on a graph built
    from a randomly smoothed copy of it.
    """

    def __init__(self, data_path: str, device, minamo: MinamoModel):
        self.data = load_data(data_path)  # custom data-loading helper
        # NOTE(review): not read anywhere in this class; kept for compatibility.
        self.max_size = 32
        self.minamo = minamo
        self.device = device

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the target map of sample ``idx`` together with its features.

        :return: dict with ``target_vision_feat`` and ``target_topo_feat`` (the
            two outputs of ``self.minamo``) and ``target`` (the un-smoothed
            one-hot map, ``[32, H, W]``, on ``self.device``).
        """
        item = self.data[idx]

        target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float()  # [32, H, W]
        min_main = random.uniform(0.75, 0.9)
        max_main = random.uniform(0.9, 1)
        epsilon = random.uniform(0, 0.25)
        # Only the graph is built from the smoothed copy; the returned target stays one-hot.
        target_smooth = random_smooth_onehot(target, min_main, max_main, epsilon)
        graph = differentiable_convert_to_data(target_smooth).to(self.device)
        target = target.to(self.device)

        # The features serve as fixed targets, so run the model without
        # recording an autograd graph; otherwise every sample would keep one alive.
        with torch.no_grad():
            vision_feat, topo_feat = self.minamo(target.unsqueeze(0), graph)

        return {
            "target_vision_feat": vision_feat,
            "target_topo_feat": topo_feat,
            "target": target,
        }
|
|
|
|
class GinkaWGANDataset(Dataset):
    """Dataset yielding randomly smoothed one-hot maps for WGAN training.

    Samples are read with ``load_data`` from the JSON file at ``data_path``.
    Each item is returned on ``device``.
    """

    def __init__(self, data_path: str, device):
        self.data = load_data(data_path)  # custom data-loading helper
        self.device = device

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        tile_ids = torch.LongTensor(self.data[idx]['map'])
        # one_hot puts the class axis last; move it to the front -> [32, H, W]
        onehot = F.one_hot(tile_ids, num_classes=32).permute(2, 0, 1).float()

        # Smoothing parameters are drawn fresh for every sample.
        low = random.uniform(0.75, 0.9)
        high = random.uniform(0.9, 1)
        eps = random.uniform(0, 0.25)

        smoothed = random_smooth_onehot(onehot, low, high, eps)
        return smoothed.to(self.device)
|
|
|
|
class MinamoGANDataset(Dataset):
    """Paired-map dataset for Minamo training, mixing supplied and reference pairs.

    Every entry is a tuple ``(map1, map2, vision_sim, topo_sim, review)``. The
    reference pairs come from ``load_minamo_gan_data``. ``map2`` is assumed to
    be the reference map.
    """

    def __init__(self, refer_data_path):
        self.refer = load_minamo_gan_data(load_data(refer_data_path))
        self.data = list()
        # Start with (up to) 1000 reference pairs; clamped so that a small
        # reference file does not make random.sample raise ValueError.
        self.data.extend(self._sample_refer(1000))

    def _sample_refer(self, k):
        """Return ``k`` distinct reference pairs, or all of them if fewer exist."""
        return random.sample(self.refer, min(int(k), len(self.refer)))

    def set_data(self, data: list):
        """Replace the contents with ``data`` plus a share of reference pairs.

        One reference pair is added for every four entries of ``data``, capped
        at the number of reference pairs available.
        """
        # Copy first: ``data`` may be ``self.data`` itself, which clear() would empty.
        incoming = list(data)
        self.data.clear()
        self.data.extend(incoming)
        self.data.extend(self._sample_refer(len(incoming) // 4))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(map1, map2, vis_sim, topo_sim, graph1, graph2)`` for entry ``idx``.

        The similarities are wrapped in one-element float tensors. The graphs
        are built with ``differentiable_convert_to_data`` from the (possibly
        smoothed) maps.
        """
        # map2 is assumed to be the reference map
        map1, map2, vis_sim, topo_sim, review = self.data[idx]

        # With the review flag set, map1 is a tile-id grid that must be
        # one-hot encoded. Without it, map1 is already a probability
        # distribution and is used as-is.
        if review:
            map1 = F.one_hot(torch.LongTensor(map1), num_classes=32).permute(2, 0, 1).float()  # [32, H, W]
        else:
            map1 = torch.FloatTensor(map1)
        map2 = F.one_hot(torch.LongTensor(map2), num_classes=32).permute(2, 0, 1).float()  # [32, H, W]

        # One set of smoothing parameters is shared by both maps of the pair.
        min_main = random.uniform(0.75, 0.9)
        max_main = random.uniform(0.9, 1)
        epsilon = random.uniform(0, 0.25)

        if review:
            map1 = random_smooth_onehot(map1, min_main, max_main, epsilon)
        map2 = random_smooth_onehot(map2, min_main, max_main, epsilon)

        graph1 = differentiable_convert_to_data(map1)
        graph2 = differentiable_convert_to_data(map2)

        return (
            map1,
            map2,
            torch.FloatTensor([vis_sim]),
            torch.FloatTensor([topo_sim]),
            graph1,
            graph2,
        )