mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-17 23:21:20 +08:00
41 lines
1.1 KiB
Python
41 lines
1.1 KiB
Python
import json
|
|
import math
|
|
import random
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch.utils.data import Dataset
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from typing import List
|
|
|
|
def load_data(path: str):
|
|
with open(path, 'r', encoding="utf-8") as f:
|
|
data = json.load(f)
|
|
|
|
data_list = []
|
|
for value in data["data"].values():
|
|
data_list.append(value)
|
|
|
|
return data_list
|
|
|
|
class GinkaMaskGITDataset(Dataset):
|
|
def __init__(self, data_path: str, device):
|
|
self.data = load_data(data_path) # 自定义数据加载函数
|
|
self.device = device
|
|
|
|
def __len__(self):
|
|
return len(self.data)
|
|
|
|
def __getitem__(self, idx):
|
|
item = self.data[idx]
|
|
|
|
target = torch.LongTensor(item['map']) # [H, W]
|
|
cond = torch.FloatTensor(item['val']) # [cond_dim]
|
|
heatmap = torch.FloatTensor(item['heatmap']) # [heatmap_channel, H, W]
|
|
|
|
return {
|
|
"cond": cond,
|
|
"target_map": target,
|
|
"heatmap": heatmap
|
|
}
|
|
|