mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-15 13:21:09 +08:00
60 lines
2.0 KiB
Python
60 lines
2.0 KiB
Python
import json
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch.utils.data import Dataset
|
|
from shared.graph import convert_soft_map_to_graph
|
|
|
|
def random_smooth_onehot(onehot_map, min_main=0.65, max_main=1.0, epsilon=0.35):
    """Randomly soften a one-hot class map into per-cell class probabilities.

    For every spatial cell the active ("main") class receives a probability
    drawn uniformly from ``[min_main, max_main)``. The leftover mass
    ``1 - main`` is split at random among the remaining classes, so the values
    along the class axis sum to 1 for every cell.

    Args:
        onehot_map: Tensor of shape ``[C, H, W]`` that is one-hot along the
            first (class) axis.
        min_main: Lower bound of the main-class probability.
        max_main: Upper bound (exclusive) of the main-class probability.
        epsilon: Scale applied to the raw noise. Any positive value cancels
            out during normalisation and does not affect the result; the
            parameter is kept so existing callers keep working.

    Returns:
        Floating-point tensor of shape ``[C, H, W]`` with the smoothed
        probabilities.
    """
    C, H, W = onehot_map.shape
    device = onehot_map.device

    # Probability assigned to the main class of each cell.
    main_prob = torch.rand(H, W, device=device) * (max_main - min_main) + min_main

    # Raw noise restricted to the non-main classes of each cell.
    off_mask = 1 - onehot_map
    noise = torch.rand(C, H, W, device=device) * epsilon * off_mask

    # Normalise over the class axis (dim 0) so the non-main classes share
    # exactly the leftover mass. The clamp avoids 0/0 when there is no other
    # class to receive mass (C == 1) or when the noise is all zero.
    total = noise.sum(dim=0, keepdim=True).clamp_min(torch.finfo(noise.dtype).tiny)
    noise = noise / total * (1 - main_prob)

    # Main class keeps main_prob; every other class gets its share of the rest.
    return onehot_map * main_prob + noise
|
|
|
|
def load_data(path: str) -> list:
    """Load the dataset entries stored in a JSON file.

    The file must contain a JSON object with a top-level ``"data"`` mapping.
    The values of that mapping are returned in the order they appear in the
    file; the keys are discarded.

    Args:
        path: Path of the UTF-8 encoded JSON file.

    Returns:
        List of the values found under ``"data"``.

    Raises:
        KeyError: If the JSON object has no ``"data"`` key.
    """
    with open(path, 'r', encoding="utf-8") as f:
        payload = json.load(f)

    return list(payload["data"].values())
|
|
|
|
class MinamoDataset(Dataset):
    """Dataset of map pairs with their similarity labels.

    Each entry loaded from the JSON file provides two integer tile maps
    (``'map1'`` and ``'map2'``) and two scalar labels (``'visionSimilarity'``
    and ``'topoSimilarity'``). Every access one-hot encodes both maps, applies
    random smoothing, and builds a graph from each smoothed map.
    """

    def __init__(self, data_path: str, num_classes: int = 32):
        """Load all entries from ``data_path``.

        Args:
            data_path: Path of the JSON file read by ``load_data``.
            num_classes: Number of tile classes used for the one-hot
                encoding. Defaults to 32.
        """
        self.data = load_data(data_path)
        self.num_classes = num_classes

    def __len__(self):
        """Return the number of loaded entries."""
        return len(self.data)

    def _encode_map(self, tile_map):
        """Turn a 2-D grid of class indices into a smoothed ``[C, H, W]`` tensor."""
        labels = torch.tensor(tile_map, dtype=torch.long)
        # one_hot appends the class axis last; move it to the front.
        onehot = F.one_hot(labels, num_classes=self.num_classes).permute(2, 0, 1).float()
        return random_smooth_onehot(onehot)

    def __getitem__(self, idx):
        """Build one training sample.

        Returns:
            A tuple ``(map1_probs, map2_probs, vision_similarity,
            topo_similarity, graph1, graph2)``. The two similarity values are
            float tensors holding a single element; the graphs are whatever
            ``convert_soft_map_to_graph`` returns for the smoothed maps.
        """
        item = self.data[idx]

        # Smoothing is random, so the same index yields different soft maps
        # on every access.
        map1_probs = self._encode_map(item['map1'])
        map2_probs = self._encode_map(item['map2'])

        graph1 = convert_soft_map_to_graph(map1_probs)
        graph2 = convert_soft_map_to_graph(map2_probs)

        return (
            map1_probs,
            map2_probs,
            torch.tensor([item['visionSimilarity']], dtype=torch.float32),
            torch.tensor([item['topoSimilarity']], dtype=torch.float32),
            graph1,
            graph2,
        )
|