mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
576 lines
22 KiB
Python
576 lines
22 KiB
Python
import json
|
||
import random
|
||
import torch
|
||
import cv2
|
||
import numpy as np
|
||
from torch.utils.data import Dataset
|
||
|
||
def load_data(path: str):
|
||
with open(path, 'r', encoding="utf-8") as f:
|
||
data = json.load(f)
|
||
|
||
data_list = []
|
||
for value in data["data"].values():
|
||
data_list.append(value)
|
||
|
||
return data_list
|
||
|
||
# 资源类别压缩:将所有资源 tile(钥匙/红宝石/蓝宝石/绿宝石/血瓶/道具)统一映射为 3
|
||
# 其余 tile 保持原始编号(enemy=9, entry=10, mask=15)
|
||
_RESOURCE_REMAP = np.array([0, 1, 2, 3, 3, 3, 3, 3, 3, 9, 10, 11, 12, 13, 14, 15], dtype=np.int64)
|
||
|
||
def remap_resources(arr: np.ndarray) -> np.ndarray:
|
||
"""将地图 numpy 数组中的资源 tile (3~8) 统一压缩为 3。"""
|
||
return _RESOURCE_REMAP[arr]
|
||
|
||
class GinkaMaskGITDataset(Dataset):
|
||
def __init__(
|
||
self, data_path: str, sigma_rand=0.1, blur_min=3, blur_max=6,
|
||
noise_prob=0.2, drop_prob=0.2, noise_sigma=0.1
|
||
):
|
||
self.data = load_data(data_path)
|
||
self.sigma_rand = sigma_rand
|
||
self.blur_min = blur_min
|
||
self.blur_max = blur_max
|
||
self.noise_prob = noise_prob
|
||
self.drop_prob = drop_prob
|
||
self.noise_sigma = noise_sigma
|
||
|
||
def __len__(self):
|
||
return len(self.data)
|
||
|
||
def __getitem__(self, idx):
|
||
item = self.data[idx]
|
||
|
||
target_np = np.array(item['map'])
|
||
heatmap = np.array(item['heatmap'], dtype=np.float32)
|
||
|
||
# 数据增强
|
||
if np.random.rand() > 0.5:
|
||
k = np.random.randint(0, 4)
|
||
target_np = np.rot90(target_np, k)
|
||
for i in range(0, heatmap.shape[0]):
|
||
heatmap[i] = np.rot90(heatmap[i], k)
|
||
|
||
if np.random.rand() > 0.5:
|
||
target_np = np.fliplr(target_np)
|
||
for i in range(0, heatmap.shape[0]):
|
||
heatmap[i] = np.fliplr(heatmap[i])
|
||
|
||
if np.random.rand() > 0.5:
|
||
target_np = np.flipud(target_np)
|
||
for i in range(0, heatmap.shape[0]):
|
||
heatmap[i] = np.flipud(heatmap[i])
|
||
|
||
target = torch.LongTensor(target_np.copy()) # [H, W]
|
||
cond = torch.FloatTensor(item['val']) # [cond_dim]
|
||
|
||
if random.random() < 0.5:
|
||
size = random.randint(self.blur_min, self.blur_max)
|
||
if size % 2 == 0:
|
||
size = size + 1 if random.random() < 0.5 else size - 1
|
||
heatmap = cv2.GaussianBlur(heatmap, (size, size), 0)
|
||
else:
|
||
sizeX = random.randint(self.blur_min, self.blur_max)
|
||
sizeY = random.randint(self.blur_min, self.blur_max)
|
||
if sizeX % 2 == 0:
|
||
sizeX = sizeX + 1 if random.random() < 0.5 else sizeX - 1
|
||
if sizeY % 2 == 0:
|
||
sizeY = sizeY + 1 if random.random() < 0.5 else sizeY - 1
|
||
heatmap = cv2.GaussianBlur(heatmap, (sizeX, sizeY), 0)
|
||
|
||
heatmap = torch.FloatTensor(heatmap) # [heatmap_channel, H, W]
|
||
|
||
for i in range(0, heatmap.shape[0]):
|
||
if np.random.rand() < self.noise_prob:
|
||
sigma = random.random() * self.noise_sigma
|
||
heatmap[i] = heatmap[i] * sigma + torch.rand_like(heatmap[i]) * (1 - sigma)
|
||
elif np.random.rand() < self.drop_prob:
|
||
heatmap[i] = torch.zeros_like(heatmap[i])
|
||
|
||
if random.random() < 0.5:
|
||
sigma = random.random() * self.sigma_rand
|
||
rand = torch.rand_like(heatmap)
|
||
heatmap = heatmap * (1 - sigma) + rand * sigma
|
||
|
||
return {
|
||
"cond": cond,
|
||
"target_map": target,
|
||
"heatmap": heatmap
|
||
}
|
||
|
||
class GinkaHeatmapDataset(Dataset):
|
||
def __init__(self, data_path: str, min_mask=0, max_mask=0.8, blur_min=3, blur_max=6):
|
||
self.data = load_data(data_path)
|
||
self.blur_min = blur_min
|
||
self.blur_max = blur_max
|
||
self.min_mask = min_mask
|
||
self.max_mask = max_mask
|
||
|
||
def __len__(self):
|
||
return len(self.data)
|
||
|
||
def __getitem__(self, idx):
|
||
item = self.data[idx]
|
||
|
||
heatmap = np.array(item['heatmap'], dtype=np.float32)
|
||
|
||
# 数据增强
|
||
if np.random.rand() > 0.5:
|
||
k = np.random.randint(0, 4)
|
||
for i in range(0, heatmap.shape[0]):
|
||
heatmap[i] = np.rot90(heatmap[i], k)
|
||
|
||
if np.random.rand() > 0.5:
|
||
for i in range(0, heatmap.shape[0]):
|
||
heatmap[i] = np.fliplr(heatmap[i])
|
||
|
||
if np.random.rand() > 0.5:
|
||
for i in range(0, heatmap.shape[0]):
|
||
heatmap[i] = np.flipud(heatmap[i])
|
||
|
||
target = heatmap.copy()
|
||
|
||
if random.random() < 0.5:
|
||
size = random.randint(self.blur_min, self.blur_max)
|
||
if size % 2 == 0:
|
||
size = size + 1 if random.random() < 0.5 else size - 1
|
||
target = cv2.GaussianBlur(target, (size, size), 0)
|
||
else:
|
||
sizeX = random.randint(self.blur_min, self.blur_max)
|
||
sizeY = random.randint(self.blur_min, self.blur_max)
|
||
if sizeX % 2 == 0:
|
||
sizeX = sizeX + 1 if random.random() < 0.5 else sizeX - 1
|
||
if sizeY % 2 == 0:
|
||
sizeY = sizeY + 1 if random.random() < 0.5 else sizeY - 1
|
||
target = cv2.GaussianBlur(target, (sizeX, sizeY), 0)
|
||
|
||
target = torch.FloatTensor(target) # [heatmap_channel, H, W]
|
||
cond = torch.FloatTensor(heatmap) # [heatmap_channel, H, W]
|
||
C, H, W = target.shape
|
||
|
||
for i in range(C):
|
||
total = H * W
|
||
ratio = np.random.random() * (self.max_mask - self.min_mask) + self.min_mask
|
||
num = int(total * ratio)
|
||
|
||
idx = np.random.choice(total, num, replace=False)
|
||
|
||
mask = np.zeros(total, dtype=bool)
|
||
mask[idx] = True
|
||
mask = mask.reshape(H, W)
|
||
cond[i, mask] = 0
|
||
|
||
return {
|
||
"target_heatmap": heatmap,
|
||
"cond_heatmap": cond
|
||
}
|
||
|
||
|
||
class GinkaJointDataset(Dataset):
|
||
def __init__(self, data_path: str, min_mask=0, max_mask=0.8, blur_min=3, blur_max=6):
|
||
self.data = load_data(data_path)
|
||
self.blur_min = blur_min
|
||
self.blur_max = blur_max
|
||
self.min_mask = min_mask
|
||
self.max_mask = max_mask
|
||
|
||
def __len__(self):
|
||
return len(self.data)
|
||
|
||
def __getitem__(self, idx):
|
||
item = self.data[idx]
|
||
|
||
target_map = np.array(item['map'])
|
||
heatmap = np.array(item['heatmap'], dtype=np.float32)
|
||
|
||
if np.random.rand() > 0.5:
|
||
k = np.random.randint(0, 4)
|
||
target_map = np.rot90(target_map, k)
|
||
for i in range(0, heatmap.shape[0]):
|
||
heatmap[i] = np.rot90(heatmap[i], k)
|
||
|
||
if np.random.rand() > 0.5:
|
||
target_map = np.fliplr(target_map)
|
||
for i in range(0, heatmap.shape[0]):
|
||
heatmap[i] = np.fliplr(heatmap[i])
|
||
|
||
if np.random.rand() > 0.5:
|
||
target_map = np.flipud(target_map)
|
||
for i in range(0, heatmap.shape[0]):
|
||
heatmap[i] = np.flipud(heatmap[i])
|
||
|
||
target_heatmap = heatmap.copy()
|
||
|
||
if random.random() < 0.5:
|
||
size = random.randint(self.blur_min, self.blur_max)
|
||
if size % 2 == 0:
|
||
size = size + 1 if random.random() < 0.5 else size - 1
|
||
target_heatmap = cv2.GaussianBlur(target_heatmap, (size, size), 0)
|
||
else:
|
||
sizeX = random.randint(self.blur_min, self.blur_max)
|
||
sizeY = random.randint(self.blur_min, self.blur_max)
|
||
if sizeX % 2 == 0:
|
||
sizeX = sizeX + 1 if random.random() < 0.5 else sizeX - 1
|
||
if sizeY % 2 == 0:
|
||
sizeY = sizeY + 1 if random.random() < 0.5 else sizeY - 1
|
||
target_heatmap = cv2.GaussianBlur(target_heatmap, (sizeX, sizeY), 0)
|
||
|
||
target_map = torch.LongTensor(target_map.copy())
|
||
target_heatmap = torch.FloatTensor(target_heatmap)
|
||
cond_heatmap = torch.FloatTensor(heatmap.copy())
|
||
channels, height, width = cond_heatmap.shape
|
||
|
||
for i in range(channels):
|
||
total = height * width
|
||
ratio = np.random.random() * (self.max_mask - self.min_mask) + self.min_mask
|
||
num = int(total * ratio)
|
||
|
||
masked_indices = np.random.choice(total, num, replace=False)
|
||
|
||
mask = np.zeros(total, dtype=bool)
|
||
mask[masked_indices] = True
|
||
mask = mask.reshape(height, width)
|
||
cond_heatmap[i, mask] = 0
|
||
|
||
return {
|
||
"target_map": target_map,
|
||
"target_heatmap": target_heatmap,
|
||
"cond_heatmap": cond_heatmap
|
||
}
|
||
|
||
|
||
def _compute_symmetry(target_np: np.ndarray) -> tuple:
|
||
"""从 numpy 地图矩阵中直接计算三种对称性,O(H*W)"""
|
||
sym_h = bool(np.all(target_np == target_np[:, ::-1]))
|
||
sym_v = bool(np.all(target_np == target_np[::-1, :]))
|
||
sym_c = bool(np.all(target_np == target_np[::-1, ::-1]))
|
||
return int(sym_h), int(sym_v), int(sym_c)
|
||
|
||
|
||
class GinkaVQDataset(Dataset):
|
||
"""
|
||
用于 VQ-VAE + MaskGIT 联合训练的多子集数据集。
|
||
|
||
每次 __getitem__ 按权重随机选取以下四种子集之一:
|
||
A (standard): 标准 MaskGIT 随机掩码,随机遮盖部分 tile
|
||
B (wall-only): 仅保留 wall(1) + floor(0),其余全部替换为 MASK(15)
|
||
C (wall-random): 在 B 基础上,再随机 mask 部分 wall tile
|
||
D (wall+entry): 仅保留 wall(1) + floor(0) + entrance(10),其余全部替换为 MASK(15)
|
||
|
||
返回 dict:
|
||
raw_map: LongTensor [H*W] 完整原始地图(供 VQ-VAE 编码)
|
||
masked_map: LongTensor [H*W] MaskGIT 输入(被 mask 的位置 = 15)
|
||
target_map: LongTensor [H*W] CE loss ground truth(等同 raw_map)
|
||
subset: str 子集标识,供调试/统计用
|
||
"""
|
||
|
||
FLOOR = 0
|
||
WALL = 1
|
||
ENTRANCE = 10
|
||
MASK_ID = 15
|
||
|
||
def __init__(
|
||
self,
|
||
data_path: str,
|
||
subset_weights: tuple = (0.5, 0.2, 0.2, 0.1),
|
||
wall_mask_ratio: float = 0.3,
|
||
room_thresholds: tuple = None,
|
||
branch_thresholds: tuple = None,
|
||
):
|
||
"""
|
||
Args:
|
||
data_path: JSON 数据文件路径
|
||
subset_weights: 子集 (A, B, C, D) 的采样权重,自动归一化
|
||
wall_mask_ratio: Subset C 中额外随机 mask 的 wall tile 比例上限
|
||
(每次从 [0, wall_mask_ratio] 均匀采样实际比例)
|
||
room_thresholds: (th1, th2) 房间数量等频分箱阈值;为 None 时自动从当前数据计算(训练集)
|
||
branch_thresholds: (th1, th2) 分支数量等频分箱阈值;为 None 时自动从当前数据计算(训练集)
|
||
"""
|
||
self.data = load_data(data_path)
|
||
self.wall_mask_ratio = wall_mask_ratio
|
||
|
||
# 累积权重,用于快速随机子集选择
|
||
total_w = sum(subset_weights)
|
||
normalized = [x / total_w for x in subset_weights]
|
||
self.subset_cumw = [sum(normalized[:i + 1]) for i in range(len(normalized))]
|
||
|
||
# ── 两趟扫描:计算等频分箱阈值 ──────────────────────────────
|
||
room_counts = [item['roomCount'] for item in self.data]
|
||
branch_counts = [item['highDegBranchCount'] for item in self.data]
|
||
|
||
if room_thresholds is None:
|
||
n = len(room_counts)
|
||
rs = sorted(room_counts)
|
||
bs = sorted(branch_counts)
|
||
th1_r, th2_r = rs[n // 3], rs[2 * n // 3]
|
||
th1_b, th2_b = bs[n // 3], bs[2 * n // 3]
|
||
# 防止 Medium 等级为空
|
||
if th1_r == th2_r:
|
||
th2_r = th1_r + 1
|
||
if th1_b == th2_b:
|
||
th2_b = th1_b + 1
|
||
self.room_th = (th1_r, th2_r)
|
||
self.branch_th = (th1_b, th2_b)
|
||
else:
|
||
self.room_th = room_thresholds
|
||
self.branch_th = branch_thresholds
|
||
|
||
def to_level(v: int, th: tuple) -> int:
|
||
return 0 if v < th[0] else (1 if v < th[1] else 2)
|
||
|
||
# 回填等级字段
|
||
for item in self.data:
|
||
item['roomCountLevel'] = to_level(item['roomCount'], self.room_th)
|
||
item['branchLevel'] = to_level(item['highDegBranchCount'], self.branch_th)
|
||
|
||
def __len__(self):
|
||
return len(self.data)
|
||
|
||
# ------------------------------------------------------------------
|
||
# 内联随机掩码生成(避免 scipy 的 NumPy 版本兼容问题)
|
||
# ------------------------------------------------------------------
|
||
@staticmethod
|
||
def _sample_mask_ratio(min_r=0.05, max_r=1.0) -> float:
|
||
"""用 Beta(2,2) 分布采样掩码比例,集中在中间值。"""
|
||
r = np.random.beta(2, 2)
|
||
return min_r + (max_r - min_r) * r
|
||
|
||
@staticmethod
|
||
def _random_mask(h: int, w: int) -> np.ndarray:
|
||
"""纯随机掩码,返回 [H*W] bool。"""
|
||
ratio = GinkaVQDataset._sample_mask_ratio()
|
||
total = h * w
|
||
idx = np.random.choice(total, int(total * ratio), replace=False)
|
||
mask = np.zeros(total, dtype=bool)
|
||
mask[idx] = True
|
||
return mask
|
||
|
||
@staticmethod
|
||
def _block_mask(h: int, w: int) -> np.ndarray:
|
||
"""矩形分块随机掩码,返回 [H*W] bool。"""
|
||
ratio = GinkaVQDataset._sample_mask_ratio()
|
||
max_block = max(2, min(h, w) // 2)
|
||
target = int(h * w * ratio)
|
||
mask = np.zeros((h, w), dtype=bool)
|
||
while mask.sum() < target:
|
||
bh = np.random.randint(2, max_block + 1)
|
||
bw = np.random.randint(2, max_block + 1)
|
||
x = np.random.randint(0, max(1, h - bh + 1))
|
||
y = np.random.randint(0, max(1, w - bw + 1))
|
||
mask[x:x + bh, y:y + bw] = True
|
||
return mask.reshape(-1)
|
||
|
||
def _std_mask(self, h: int, w: int) -> np.ndarray:
|
||
"""标准 MaskGIT 掩码:随机选择纯随机或分块策略。"""
|
||
if random.random() < 0.5:
|
||
return self._random_mask(h, w)
|
||
else:
|
||
return self._block_mask(h, w)
|
||
|
||
# ------------------------------------------------------------------
|
||
|
||
def _augment(self, arr: np.ndarray) -> np.ndarray:
|
||
"""随机旋转 / 翻转数据增强,返回新 array。"""
|
||
if np.random.rand() > 0.5:
|
||
k = np.random.randint(1, 4)
|
||
arr = np.rot90(arr, k).copy()
|
||
if np.random.rand() > 0.5:
|
||
arr = np.fliplr(arr).copy()
|
||
if np.random.rand() > 0.5:
|
||
arr = np.flipud(arr).copy()
|
||
return arr
|
||
|
||
def _choose_subset(self) -> str:
|
||
r = random.random()
|
||
if r < self.subset_cumw[0]:
|
||
return 'A'
|
||
elif r < self.subset_cumw[1]:
|
||
return 'B'
|
||
elif r < self.subset_cumw[2]:
|
||
return 'C'
|
||
else:
|
||
return 'D'
|
||
|
||
def _apply_subset(self, raw: np.ndarray, subset: str) -> np.ndarray:
|
||
"""
|
||
根据子集类型生成 masked_map。
|
||
|
||
Args:
|
||
raw: [H, W] int64 完整原始地图
|
||
subset: 'A' | 'B' | 'C' | 'D'
|
||
|
||
Returns:
|
||
[H*W] int64,被遮盖位置值为 MASK_ID(15)
|
||
"""
|
||
H, W = raw.shape
|
||
|
||
if subset == 'A':
|
||
# 标准随机 mask:纯随机或分块策略
|
||
mask = self._std_mask(H, W) # [H*W] bool
|
||
flat = raw.reshape(-1).copy()
|
||
flat[mask] = self.MASK_ID
|
||
return flat
|
||
|
||
elif subset == 'B':
|
||
# 仅保留 wall(1),floor(0) 和其他非墙内容全部 mask
|
||
flat = raw.reshape(-1).copy()
|
||
keep = (flat == self.WALL)
|
||
flat[~keep] = self.MASK_ID
|
||
return flat
|
||
|
||
elif subset == 'C':
|
||
# Subset B + 随机 mask 部分 wall
|
||
flat = raw.reshape(-1).copy()
|
||
keep = (flat == self.WALL)
|
||
flat[~keep] = self.MASK_ID
|
||
|
||
wall_idx = np.where(flat == self.WALL)[0]
|
||
if len(wall_idx) > 0:
|
||
ratio = random.random() * self.wall_mask_ratio
|
||
n = max(1, int(len(wall_idx) * ratio))
|
||
chosen = np.random.choice(wall_idx, n, replace=False)
|
||
flat[chosen] = self.MASK_ID
|
||
return flat
|
||
|
||
else: # D
|
||
# 仅保留 wall(1) 和 entrance(10),floor(0) 和其他非墙内容全部 mask
|
||
flat = raw.reshape(-1).copy()
|
||
keep = (flat == self.WALL) | (flat == self.ENTRANCE)
|
||
flat[~keep] = self.MASK_ID
|
||
|
||
# 随机 mask 部分 wall(模拟真实场景,与子集 C 一致)
|
||
wall_idx = np.where(flat == self.WALL)[0]
|
||
if len(wall_idx) > 0:
|
||
ratio = random.random() * self.wall_mask_ratio
|
||
n = max(1, int(len(wall_idx) * ratio))
|
||
chosen = np.random.choice(wall_idx, n, replace=False)
|
||
flat[chosen] = self.MASK_ID
|
||
return flat
|
||
|
||
def __getitem__(self, idx):
|
||
item = self.data[idx]
|
||
|
||
raw_np = self._augment(np.array(item['map'], dtype=np.int64)) # [H, W]
|
||
raw_np = remap_resources(raw_np) # 资源压缩
|
||
subset = self._choose_subset()
|
||
masked_np = self._apply_subset(raw_np, subset) # [H*W]
|
||
raw_flat = raw_np.reshape(-1) # [H*W]
|
||
|
||
# 对称性:在增强后重新计算
|
||
sym_h, sym_v, sym_c = _compute_symmetry(raw_np)
|
||
cond_sym = sym_h * 4 + sym_v * 2 + sym_c # [0, 7]
|
||
|
||
# 其余结构标签:增强不改变拓扑结构,直接读取
|
||
cond_room = item['roomCountLevel'] # 0/1/2
|
||
cond_branch = item['branchLevel'] # 0/1/2
|
||
cond_outer = item['outerWall'] # 0/1
|
||
|
||
struct_cond = torch.LongTensor([cond_sym, cond_room, cond_branch, cond_outer])
|
||
|
||
raw_t = torch.LongTensor(raw_flat)
|
||
return {
|
||
"raw_map": raw_t, # VQ-VAE 编码器输入
|
||
"slice1": make_slice(raw_t, {0, 1}), # 通道 1:floor+wall
|
||
"slice2": make_slice(raw_t, {0, 1, 2, 9, 10}),# 通道 2:floor+wall+门+怪+入口
|
||
"slice3": raw_t.clone(), # 通道 3:完整地图
|
||
"masked_map": torch.LongTensor(masked_np), # MaskGIT 输入
|
||
"target_map": torch.LongTensor(raw_flat.copy()), # CE loss ground truth
|
||
"subset": subset, # 调试/统计用
|
||
"struct_cond": struct_cond, # [4],供模型 Embedding 查表
|
||
}
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# make_slice:按保留集合切割地图,其余位置替换为 floor(0)
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def make_slice(map_flat: torch.Tensor, keep_set: set) -> torch.Tensor:
|
||
"""
|
||
从完整地图中只保留 keep_set 中的 tile 类型,其余位置替换为 floor(0)。
|
||
|
||
Args:
|
||
map_flat: LongTensor [H*W] 完整地图 tile ID 序列
|
||
keep_set: set of int 需要保留的 tile 类型集合
|
||
|
||
Returns:
|
||
LongTensor [H*W] 切片后的地图(非保留 tile 位置值为 0)
|
||
"""
|
||
out = map_flat.clone()
|
||
mask = torch.zeros_like(out, dtype=torch.bool)
|
||
for t in keep_set:
|
||
mask |= (out == t)
|
||
out[~mask] = 0
|
||
return out
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# GinkaSplitDataset:三通道分拆预训练专用数据集
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class GinkaSplitDataset(Dataset):
|
||
"""
|
||
三通道分拆预训练(方案 B)专用数据集。
|
||
|
||
每个样本只提供完整地图及其三路切片,不做 MaskGIT 掩码处理。
|
||
切片按累积式设计:
|
||
slice1 = floor(0) + wall(1)
|
||
slice2 = floor(0) + wall(1) + door(2) + mob(9) + entrance(10)
|
||
slice3 = 完整地图(所有 tile)
|
||
|
||
返回 dict:
|
||
raw_map: LongTensor [H*W] 完整原始地图
|
||
slice1: LongTensor [H*W] 通道 1 切片(floor+wall)
|
||
slice2: LongTensor [H*W] 通道 2 切片(floor+wall+门+怪+入口)
|
||
slice3: LongTensor [H*W] 通道 3 切片(完整地图)
|
||
"""
|
||
|
||
def __init__(self, data_path: str):
|
||
self.data = load_data(data_path)
|
||
|
||
def __len__(self):
|
||
return len(self.data)
|
||
|
||
def __getitem__(self, idx):
|
||
item = self.data[idx]
|
||
arr = np.array(item['map'], dtype=np.int64) # [H, W]
|
||
arr = remap_resources(arr) # 资源压缩
|
||
|
||
# 随机旋转 / 翻转数据增强
|
||
if np.random.rand() > 0.5:
|
||
k = np.random.randint(1, 4)
|
||
arr = np.rot90(arr, k).copy()
|
||
if np.random.rand() > 0.5:
|
||
arr = np.fliplr(arr).copy()
|
||
if np.random.rand() > 0.5:
|
||
arr = np.flipud(arr).copy()
|
||
|
||
raw = torch.LongTensor(arr.reshape(-1)) # [H*W]
|
||
return {
|
||
"raw_map": raw,
|
||
"slice1": make_slice(raw, {0, 1}),
|
||
"slice2": make_slice(raw, {0, 1, 2, 9, 10}),
|
||
"slice3": raw.clone(),
|
||
}
|
||
|
||
|
||
if __name__ == "__main__":
|
||
import os
|
||
data_path = os.path.join(os.path.dirname(__file__), '..', 'data', 'ginka-dataset.json')
|
||
ds = GinkaVQDataset(data_path)
|
||
print(f"数据集大小: {len(ds)}")
|
||
|
||
subset_count = {'A': 0, 'B': 0, 'C': 0, 'D': 0}
|
||
for i in range(200):
|
||
sample = ds[i % len(ds)]
|
||
subset_count[sample['subset']] += 1
|
||
|
||
raw = sample['raw_map']
|
||
masked = sample['masked_map']
|
||
target = sample['target_map']
|
||
print(f"raw_map shape={raw.shape}, dtype={raw.dtype}")
|
||
print(f"masked_map shape={masked.shape}, dtype={masked.dtype}")
|
||
print(f"target_map shape={target.shape}, dtype={target.dtype}")
|
||
print(f"被 mask 的位置数: {(masked == 15).sum().item()} / {masked.numel()}")
|
||
print(f"\n200 次采样子集分布: {subset_count}")
|