ginka-generator/ginka/dataset.py
2026-03-31 22:24:25 +08:00

77 lines
2.4 KiB
Python

import json
import random
import torch
import cv2
import numpy as np
from torch.utils.data import Dataset
def load_data(path: str):
with open(path, 'r', encoding="utf-8") as f:
data = json.load(f)
data_list = []
for value in data["data"].values():
data_list.append(value)
return data_list
class GinkaMaskGITDataset(Dataset):
def __init__(self, data_path: str, sigma_rand=0.1, blur_min=3, blur_max=6):
self.data = load_data(data_path)
self.sigma_rand = sigma_rand
self.blur_min = blur_min
self.blur_max = blur_max
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
item = self.data[idx]
target_np = np.array(item['map'])
heatmap = np.array(item['heatmap'], dtype=np.float32)
# 数据增强
if np.random.rand() > 0.5:
k = np.random.randint(0, 4)
target_np = np.rot90(target_np, k)
heatmap = np.rot90(heatmap, k)
if np.random.rand() > 0.5:
target_np = np.fliplr(target_np)
heatmap = np.fliplr(heatmap)
if np.random.rand() > 0.5:
target_np = np.flipud(target_np)
heatmap = np.flipud(heatmap)
target = torch.LongTensor(target_np) # [H, W]
cond = torch.FloatTensor(item['val']) # [cond_dim]
if random.random() < 0.5:
size = random.randint(self.blur_min, self.blur_max)
if size % 2 == 0:
size = size + 1 if random.random() < 0.5 else size - 1
heatmap = cv2.GaussianBlur(heatmap, (size, size), 0)
else:
sizeX = random.randint(self.blur_min, self.blur_max)
sizeY = random.randint(self.blur_min, self.blur_max)
if sizeX % 2 == 0:
sizeX = sizeX + 1 if random.random() < 0.5 else sizeX - 1
if sizeY % 2 == 0:
sizeY = sizeY + 1 if random.random() < 0.5 else sizeY - 1
heatmap = cv2.GaussianBlur(heatmap, (sizeX, sizeY), 0)
heatmap = torch.FloatTensor(heatmap) # [heatmap_channel, H, W]
if random.random() < 0.5:
sigma = random.random() * self.sigma_rand
rand = torch.randn_like(heatmap) * sigma
heatmap = heatmap + rand
return {
"cond": cond,
"target_map": target,
"heatmap": heatmap
}