ginka-generator/ginka/dataset.py
unanmed 3676958781 feat: 分三阶段训练
Co-authored-by: Copilot <copilot@github.com>
2026-05-07 20:59:22 +08:00

719 lines
29 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import random
import torch
import numpy as np
from torch.utils.data import Dataset
def _compute_map_labels(map_2d) -> dict:
"""
从 2D 地图列表(或 numpy 数组)推算结构标签。
当 JSON 数据缺少 roomCount / highDegBranchCount / outerWall 字段时调用。
"""
arr = np.array(map_2d, dtype=np.int64) # [H, W]
H, W = arr.shape
WALL, ENTRY = 1, 5
# outerWall最外圈中 wall+entry 占比 > 90%
border = np.concatenate([arr[0, :], arr[-1, :], arr[1:-1, 0], arr[1:-1, -1]])
total_b = border.size
outer_wall = int(total_b > 0 and np.sum((border == WALL) | (border == ENTRY)) / total_b > 0.9)
# roomCountBFS 统计 floor(0)+resource(3) 连通区域,
# 需满足:总格子 >= 4外接矩形宽 >= 2 且高 >= 2
FLOOR_SET = (0, 3)
visited = np.zeros((H, W), dtype=bool)
room_count = 0
for sy in range(H):
for sx in range(W):
if arr[sy, sx] not in FLOOR_SET or visited[sy, sx]:
continue
queue = [(sy, sx)]
visited[sy, sx] = True
tiles_y, tiles_x = [sy], [sx]
head = 0
while head < len(queue):
y, x = queue[head]; head += 1
for dy, dx in ((-1, 0), (1, 0), (0, -1), (0, 1)):
ny, nx = y + dy, x + dx
if 0 <= ny < H and 0 <= nx < W and not visited[ny, nx] and arr[ny, nx] in FLOOR_SET:
visited[ny, nx] = True
queue.append((ny, nx))
tiles_y.append(ny); tiles_x.append(nx)
if (len(tiles_y) >= 4
and max(tiles_y) - min(tiles_y) >= 1
and max(tiles_x) - min(tiles_x) >= 1):
room_count += 1
# highDegBranchCount非 wall 格子中4 邻域非 wall 邻居 >= 3 的数量
non_wall = (arr != WALL).astype(np.int32)
padded = np.pad(non_wall, 1, mode='constant', constant_values=0)
nbr_sum = (padded[:-2, 1:-1] + padded[2:, 1:-1] +
padded[1:-1, :-2] + padded[1:-1, 2:])
high_deg = int(np.sum((non_wall == 1) & (nbr_sum >= 3)))
return {'outerWall': outer_wall, 'roomCount': room_count, 'highDegBranchCount': high_deg}
def load_data(path: str):
with open(path, 'r', encoding="utf-8") as f:
data = json.load(f)
data_list = []
for value in data["data"].values():
# 兼容旧版数据集(缺少结构标签字段)
if 'roomCount' not in value:
labels = _compute_map_labels(value['map'])
value.update(labels)
# symmetry 字段由 __getitem__ 在增强后重新计算,此处不需要从 JSON 读取
data_list.append(value)
return data_list
def _compute_symmetry(target_np: np.ndarray) -> tuple:
"""从 numpy 地图矩阵中直接计算三种对称性O(H*W)"""
sym_h = bool(np.all(target_np == target_np[:, ::-1]))
sym_v = bool(np.all(target_np == target_np[::-1, :]))
sym_c = bool(np.all(target_np == target_np[::-1, ::-1]))
return int(sym_h), int(sym_v), int(sym_c)
class GinkaVQDataset(Dataset):
"""
用于 VQ-VAE + MaskGIT 联合训练的多子集数据集。
每次 __getitem__ 按权重随机选取以下四种子集之一:
A (standard): 标准 MaskGIT 随机掩码,随机遮盖部分 tile
B (wall-only): 仅保留 wall(1) + floor(0),其余全部替换为 MASK(6)
C (wall-random): 在 B 基础上,再随机 mask 部分 wall tile
D (wall+entry): 仅保留 wall(1) + floor(0) + entrance(5),其余全部替换为 MASK(6)
返回 dict:
raw_map: LongTensor [H*W] 完整原始地图(供 VQ-VAE 编码)
masked_map: LongTensor [H*W] MaskGIT 输入(被 mask 的位置 = 6
target_map: LongTensor [H*W] CE loss ground truth等同 raw_map
subset: str 子集标识,供调试/统计用
"""
FLOOR = 0
WALL = 1
ENTRANCE = 5
MASK_ID = 6
def __init__(
self,
data_path: str,
subset_weights: tuple = (0.5, 0.2, 0.2, 0.1),
wall_mask_ratio: float = 0.3,
room_thresholds: tuple = None,
branch_thresholds: tuple = None,
):
"""
Args:
data_path: JSON 数据文件路径
subset_weights: 子集 (A, B, C, D) 的采样权重,自动归一化
wall_mask_ratio: Subset C 中额外随机 mask 的 wall tile 比例上限
(每次从 [0, wall_mask_ratio] 均匀采样实际比例)
room_thresholds: (th1, th2) 房间数量等频分箱阈值;为 None 时自动从当前数据计算(训练集)
branch_thresholds: (th1, th2) 分支数量等频分箱阈值;为 None 时自动从当前数据计算(训练集)
"""
self.data = load_data(data_path)
self.wall_mask_ratio = wall_mask_ratio
# 累积权重,用于快速随机子集选择
total_w = sum(subset_weights)
normalized = [x / total_w for x in subset_weights]
self.subset_cumw = [sum(normalized[:i + 1]) for i in range(len(normalized))]
# ── 两趟扫描:计算等频分箱阈值 ──────────────────────────────
room_counts = [item['roomCount'] for item in self.data]
branch_counts = [item['highDegBranchCount'] for item in self.data]
if room_thresholds is None:
n = len(room_counts)
rs = sorted(room_counts)
bs = sorted(branch_counts)
th1_r, th2_r = rs[n // 3], rs[2 * n // 3]
th1_b, th2_b = bs[n // 3], bs[2 * n // 3]
# 防止 Medium 等级为空
if th1_r == th2_r:
th2_r = th1_r + 1
if th1_b == th2_b:
th2_b = th1_b + 1
self.room_th = (th1_r, th2_r)
self.branch_th = (th1_b, th2_b)
else:
self.room_th = room_thresholds
self.branch_th = branch_thresholds
def to_level(v: int, th: tuple) -> int:
return 0 if v < th[0] else (1 if v < th[1] else 2)
# 回填等级字段
for item in self.data:
item['roomCountLevel'] = to_level(item['roomCount'], self.room_th)
item['branchLevel'] = to_level(item['highDegBranchCount'], self.branch_th)
def __len__(self):
return len(self.data)
# ------------------------------------------------------------------
# 内联随机掩码生成(避免 scipy 的 NumPy 版本兼容问题)
# ------------------------------------------------------------------
@staticmethod
def _sample_mask_ratio(min_r=0.05, max_r=1.0) -> float:
"""用 Beta(2,2) 分布采样掩码比例,集中在中间值。"""
r = np.random.beta(2, 2)
return min_r + (max_r - min_r) * r
@staticmethod
def _random_mask(h: int, w: int) -> np.ndarray:
"""纯随机掩码,返回 [H*W] bool。"""
ratio = GinkaVQDataset._sample_mask_ratio()
total = h * w
idx = np.random.choice(total, int(total * ratio), replace=False)
mask = np.zeros(total, dtype=bool)
mask[idx] = True
return mask
@staticmethod
def _block_mask(h: int, w: int) -> np.ndarray:
"""矩形分块随机掩码,返回 [H*W] bool。"""
ratio = GinkaVQDataset._sample_mask_ratio()
max_block = max(2, min(h, w) // 2)
target = int(h * w * ratio)
mask = np.zeros((h, w), dtype=bool)
while mask.sum() < target:
bh = np.random.randint(2, max_block + 1)
bw = np.random.randint(2, max_block + 1)
x = np.random.randint(0, max(1, h - bh + 1))
y = np.random.randint(0, max(1, w - bw + 1))
mask[x:x + bh, y:y + bw] = True
return mask.reshape(-1)
def _std_mask(self, h: int, w: int) -> np.ndarray:
"""标准 MaskGIT 掩码:随机选择纯随机或分块策略。"""
if random.random() < 0.5:
return self._random_mask(h, w)
else:
return self._block_mask(h, w)
# ------------------------------------------------------------------
def _augment(self, arr: np.ndarray) -> np.ndarray:
"""随机旋转 / 翻转数据增强,返回新 array。"""
if np.random.rand() > 0.5:
k = np.random.randint(1, 4)
arr = np.rot90(arr, k).copy()
if np.random.rand() > 0.5:
arr = np.fliplr(arr).copy()
if np.random.rand() > 0.5:
arr = np.flipud(arr).copy()
return arr
def _choose_subset(self) -> str:
r = random.random()
if r < self.subset_cumw[0]:
return 'A'
elif r < self.subset_cumw[1]:
return 'B'
elif r < self.subset_cumw[2]:
return 'C'
else:
return 'D'
def _apply_subset(self, raw: np.ndarray, subset: str) -> np.ndarray:
"""
根据子集类型生成 masked_map。
Args:
raw: [H, W] int64 完整原始地图
subset: 'A' | 'B' | 'C' | 'D'
Returns:
[H*W] int64被遮盖位置値为 MASK_ID(6)
"""
H, W = raw.shape
if subset == 'A':
# 标准随机 mask纯随机或分块策略
mask = self._std_mask(H, W) # [H*W] bool
flat = raw.reshape(-1).copy()
flat[mask] = self.MASK_ID
return flat
elif subset == 'B':
# 仅保留 wall(1)floor(0) 和其他非墙内容全部 mask
flat = raw.reshape(-1).copy()
keep = (flat == self.WALL)
flat[~keep] = self.MASK_ID
return flat
elif subset == 'C':
# Subset B + 随机 mask 部分 wall
flat = raw.reshape(-1).copy()
keep = (flat == self.WALL)
flat[~keep] = self.MASK_ID
wall_idx = np.where(flat == self.WALL)[0]
if len(wall_idx) > 0:
ratio = random.random() * self.wall_mask_ratio
n = max(1, int(len(wall_idx) * ratio))
chosen = np.random.choice(wall_idx, n, replace=False)
flat[chosen] = self.MASK_ID
return flat
else: # D
# 仅保留 wall(1) 和 entrance(5)floor(0) 和其他非墙内容全部 mask
flat = raw.reshape(-1).copy()
keep = (flat == self.WALL) | (flat == self.ENTRANCE)
flat[~keep] = self.MASK_ID
# 随机 mask 部分 wall模拟真实场景与子集 C 一致)
wall_idx = np.where(flat == self.WALL)[0]
if len(wall_idx) > 0:
ratio = random.random() * self.wall_mask_ratio
n = max(1, int(len(wall_idx) * ratio))
chosen = np.random.choice(wall_idx, n, replace=False)
flat[chosen] = self.MASK_ID
return flat
def __getitem__(self, idx):
item = self.data[idx]
raw_np = self._augment(np.array(item['map'], dtype=np.int64)) # [H, W]
subset = self._choose_subset()
masked_np = self._apply_subset(raw_np, subset) # [H*W]
raw_flat = raw_np.reshape(-1) # [H*W]
# 对称性:在增强后重新计算
sym_h, sym_v, sym_c = _compute_symmetry(raw_np)
cond_sym = sym_h * 4 + sym_v * 2 + sym_c # [0, 7]
# 其余结构标签:增强不改变拓扑结构,直接读取
cond_room = item['roomCountLevel'] # 0/1/2
cond_branch = item['branchLevel'] # 0/1/2
cond_outer = item['outerWall'] # 0/1
struct_cond = torch.LongTensor([cond_sym, cond_room, cond_branch, cond_outer])
raw_t = torch.LongTensor(raw_flat)
return {
"raw_map": raw_t, # VQ-VAE 编码器输入
"slice1": make_slice(raw_t, {0, 1}), # 通道 1floor+wall
"slice2": make_slice(raw_t, {0, 1, 2, 4, 5}), # 通道 2floor+wall+门+怪+入口
"slice3": raw_t.clone(), # 通道 3完整地图
"masked_map": torch.LongTensor(masked_np), # MaskGIT 输入
"target_map": torch.LongTensor(raw_flat.copy()), # CE loss ground truth
"subset": subset, # 调试/统计用
"struct_cond": struct_cond, # [4],供模型 Embedding 查表
}
# ---------------------------------------------------------------------------
# make_slice按保留集合切割地图其余位置替换为 floor(0)
# ---------------------------------------------------------------------------
def make_slice(map_flat: torch.Tensor, keep_set: set) -> torch.Tensor:
"""
从完整地图中只保留 keep_set 中的 tile 类型,其余位置替换为 floor(0)。
Args:
map_flat: LongTensor [H*W] 完整地图 tile ID 序列
keep_set: set of int 需要保留的 tile 类型集合
Returns:
LongTensor [H*W] 切片后的地图(非保留 tile 位置值为 0
"""
out = map_flat.clone()
mask = torch.zeros_like(out, dtype=torch.bool)
for t in keep_set:
mask |= (out == t)
out[~mask] = 0
return out
# ---------------------------------------------------------------------------
# GinkaSplitDataset三通道分拆预训练专用数据集
# ---------------------------------------------------------------------------
class GinkaSplitDataset(Dataset):
"""
三通道分拆预训练(方案 B专用数据集。
每个样本只提供完整地图及其三路切片,不做 MaskGIT 掩码处理。
切片按累积式设计:
slice1 = floor(0) + wall(1)
slice2 = floor(0) + wall(1) + door(2) + mob(4) + entrance(5)
slice3 = 完整地图(所有 tile
返回 dict:
raw_map: LongTensor [H*W] 完整原始地图
slice1: LongTensor [H*W] 通道 1 切片floor+wall
slice2: LongTensor [H*W] 通道 2 切片floor+wall+门+怪+入口)
slice3: LongTensor [H*W] 通道 3 切片(完整地图)
"""
def __init__(self, data_path: str):
self.data = load_data(data_path)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
item = self.data[idx]
arr = np.array(item['map'], dtype=np.int64) # [H, W]
# 随机旋转 / 翻转数据增强
if np.random.rand() > 0.5:
k = np.random.randint(1, 4)
arr = np.rot90(arr, k).copy()
if np.random.rand() > 0.5:
arr = np.fliplr(arr).copy()
if np.random.rand() > 0.5:
arr = np.flipud(arr).copy()
raw = torch.LongTensor(arr.reshape(-1)) # [H*W]
return {
"raw_map": raw,
"slice1": make_slice(raw, {0, 1}),
"slice2": make_slice(raw, {0, 1, 2, 4, 5}),
"slice3": raw.clone(),
}
# ---------------------------------------------------------------------------
# GinkaStageDataset三阶段级联训练专用数据集
# ---------------------------------------------------------------------------
class GinkaStageDataset(Dataset):
"""
三阶段级联生成训练专用 Dataset。
每个阶段只预测特定类别的 tile后续阶段以前序阶段输出作为上下文。
训练时统一使用 GT 作为前序上下文teacher forcing避免误差级联。
阶段划分:
stage=1 结构骨架:预测 floor(0) + wall(1)
stage=2 功能元素:预测 door(2) + monster(4) + entrance(5),以 floor/wall 为上下文
stage=3 资源放置:预测 resource(3),以完整骨架为上下文
返回 dict:
raw_map: LongTensor [H*W] 完整原始地图(供 VQ-VAE 编码)
vq_slice: LongTensor [H*W] 当前阶段 VQ 编码器的输入切片
stage_input: LongTensor [H*W] MaskGIT 输入(含上下文 + MASK 位置)
target_map: LongTensor [H*W] CE loss ground truth
loss_mask: BoolTensor [H*W] 只对 True 位置计算损失
subset: str 子集标识 A/B/C/D
struct_cond: LongTensor [4] [sym, room, branch, outer]
"""
FLOOR = 0
WALL = 1
DOOR = 2
RESOURCE = 3
MONSTER = 4
ENTRANCE = 5
MASK_ID = 6
STAGE1_TARGETS = frozenset({0, 1})
STAGE2_TARGETS = frozenset({2, 4, 5})
STAGE3_TARGETS = frozenset({3})
# VQ 切片集合:各阶段编码器只"看"与自身相关的 tile
_VQ_KEEP = {
1: frozenset({0, 1}),
2: frozenset({0, 1, 2, 4, 5}),
3: None, # 完整地图
}
def __init__(
self,
data_path: str,
stage: int,
subset_weights: tuple = (0.5, 0.2, 0.2, 0.1),
wall_mask_ratio: float = 0.3,
room_thresholds: tuple = None,
branch_thresholds: tuple = None,
):
"""
Args:
data_path: JSON 数据文件路径
stage: 生成阶段 1/2/3
subset_weights: 子集 (A, B, C, D) 的采样权重,自动归一化
wall_mask_ratio: Subset C 中额外随机 mask 的 wall 比例上限
room_thresholds: 等频分箱阈值None 时自动计算)
branch_thresholds: 等频分箱阈值None 时自动计算)
"""
assert stage in (1, 2, 3), f"stage 必须是 1/2/3收到 {stage}"
self.stage = stage
self.data = load_data(data_path)
self.wall_mask_ratio = wall_mask_ratio
total_w = sum(subset_weights)
normalized = [x / total_w for x in subset_weights]
self.subset_cumw = [sum(normalized[:i + 1]) for i in range(len(normalized))]
room_counts = [item['roomCount'] for item in self.data]
branch_counts = [item['highDegBranchCount'] for item in self.data]
if room_thresholds is None:
n = len(room_counts)
rs = sorted(room_counts)
bs = sorted(branch_counts)
th1_r, th2_r = rs[n // 3], rs[2 * n // 3]
th1_b, th2_b = bs[n // 3], bs[2 * n // 3]
if th1_r == th2_r: th2_r = th1_r + 1
if th1_b == th2_b: th2_b = th1_b + 1
self.room_th = (th1_r, th2_r)
self.branch_th = (th1_b, th2_b)
else:
self.room_th = room_thresholds
self.branch_th = branch_thresholds
def to_level(v, th):
return 0 if v < th[0] else (1 if v < th[1] else 2)
for item in self.data:
item['roomCountLevel'] = to_level(item['roomCount'], self.room_th)
item['branchLevel'] = to_level(item['highDegBranchCount'], self.branch_th)
def __len__(self):
return len(self.data)
# ------------------------------------------------------------------
# 掩码辅助(与 GinkaVQDataset 相同逻辑)
# ------------------------------------------------------------------
@staticmethod
def _sample_mask_ratio(min_r=0.05, max_r=1.0) -> float:
r = np.random.beta(2, 2)
return min_r + (max_r - min_r) * r
@staticmethod
def _random_mask(h: int, w: int) -> np.ndarray:
ratio = GinkaStageDataset._sample_mask_ratio()
total = h * w
idx = np.random.choice(total, int(total * ratio), replace=False)
mask = np.zeros(total, dtype=bool)
mask[idx] = True
return mask
@staticmethod
def _block_mask(h: int, w: int) -> np.ndarray:
ratio = GinkaStageDataset._sample_mask_ratio()
max_block = max(2, min(h, w) // 2)
target = int(h * w * ratio)
mask = np.zeros((h, w), dtype=bool)
while mask.sum() < target:
bh = np.random.randint(2, max_block + 1)
bw = np.random.randint(2, max_block + 1)
x = np.random.randint(0, max(1, h - bh + 1))
y = np.random.randint(0, max(1, w - bw + 1))
mask[x:x + bh, y:y + bw] = True
return mask.reshape(-1)
def _std_mask(self, h: int, w: int) -> np.ndarray:
return self._random_mask(h, w) if random.random() < 0.5 else self._block_mask(h, w)
# ------------------------------------------------------------------
# 子集选择
# ------------------------------------------------------------------
def _choose_subset(self) -> str:
r = random.random()
if r < self.subset_cumw[0]: return 'A'
if r < self.subset_cumw[1]: return 'B'
if r < self.subset_cumw[2]: return 'C'
return 'D'
# ------------------------------------------------------------------
# 阶段一结构骨架floor + wall
# ------------------------------------------------------------------
def _make_stage1(self, raw_flat: np.ndarray, subset: str):
"""
阶段一:预测 floor/wall所有非 floor/wall tile 在目标中重映射为 floor。
子集决定向模型提供多少 wall 作为上下文条件。
"""
H = W = 13
# 目标:非 floor/wall → floor
target = raw_flat.copy()
target[~np.isin(target, [self.FLOOR, self.WALL])] = self.FLOOR
inp = target.copy()
if subset == 'A':
# 标准随机 mask随机遮盖部分 floor/wall
mask = self._std_mask(H, W)
inp[mask] = self.MASK_ID
elif subset == 'B':
# 保留全部 wallMASK floor
inp[inp == self.FLOOR] = self.MASK_ID
elif subset == 'C':
# 随机保留部分 wallMASK 其余(含全部 floor
inp[inp == self.FLOOR] = self.MASK_ID
wall_idx = np.where(inp == self.WALL)[0]
if len(wall_idx) > 0:
ratio = random.random() * self.wall_mask_ratio
n = max(1, int(len(wall_idx) * ratio))
chosen = np.random.choice(wall_idx, n, replace=False)
inp[chosen] = self.MASK_ID
else: # D与 B 相同(阶段一无 entrance 维度)
inp[inp == self.FLOOR] = self.MASK_ID
loss_mask = (inp == self.MASK_ID)
return inp, target, loss_mask
# ------------------------------------------------------------------
# 阶段二功能元素door + monster + entrance
# ------------------------------------------------------------------
def _make_stage2(self, raw_flat: np.ndarray, subset: str):
"""
阶段二:以 floor/wall 为上下文,预测 door/monster/entrance。
resource 在输入与目标中均视为 floor阶段二不负责资源
子集决定 wall 上下文的完整程度与 door/monster/entrance 的掩码方式。
"""
# 目标resource → floor
target = raw_flat.copy()
target[target == self.RESOURCE] = self.FLOOR
# 基础输入resource → floor功能元素先保留再按子集处理
inp = raw_flat.copy()
inp[inp == self.RESOURCE] = self.FLOOR
if subset == 'A':
# 随机遮盖部分 door/monster/entrance部分上下文补全
func_idx = np.where(np.isin(inp, [self.DOOR, self.MONSTER, self.ENTRANCE]))[0]
if len(func_idx) > 0:
ratio = random.random() * 0.8 + 0.2 # 20%~100%
n = max(1, int(len(func_idx) * ratio))
chosen = np.random.choice(func_idx, n, replace=False)
inp[chosen] = self.MASK_ID
else:
# B/C/D全部 door/monster/entrance → MASK
inp[np.isin(inp, [self.DOOR, self.MONSTER, self.ENTRANCE])] = self.MASK_ID
if subset == 'C':
# 额外随机 mask 部分 wall降低 wall 上下文质量)
wall_idx = np.where(inp == self.WALL)[0]
if len(wall_idx) > 0:
ratio = random.random() * self.wall_mask_ratio
n = max(1, int(len(wall_idx) * ratio))
chosen = np.random.choice(wall_idx, n, replace=False)
inp[chosen] = self.MASK_ID
# loss_mask阶段二只对 door/monster/entrance 原始位置计算损失,
# 不对被额外 mask 的 wall 位置计算(它们在 target 中已知为 wall
loss_mask = np.isin(raw_flat, [self.DOOR, self.MONSTER, self.ENTRANCE])
return inp, target, loss_mask
# ------------------------------------------------------------------
# 阶段三资源放置resource
# ------------------------------------------------------------------
def _make_stage3(self, raw_flat: np.ndarray, subset: str):
"""
阶段三:以完整骨架为上下文,预测 resource 位置。
所有 resource 位置在输入中替换为 MASK。
子集 A 随机保留部分 resource 作为上下文(部分补全训练),
其余子集始终 MASK 全部 resource。
"""
target = raw_flat.copy()
inp = raw_flat.copy()
if subset == 'A':
# 随机遮盖部分 resource部分上下文补全
res_idx = np.where(inp == self.RESOURCE)[0]
if len(res_idx) > 0:
ratio = random.random() * 0.8 + 0.2 # 20%~100%
n = max(1, int(len(res_idx) * ratio))
chosen = np.random.choice(res_idx, n, replace=False)
inp[chosen] = self.MASK_ID
else:
pass # 无 resource 时无需处理
else:
# B/C/D全部 resource → MASK
inp[inp == self.RESOURCE] = self.MASK_ID
loss_mask = (inp == self.MASK_ID)
return inp, target, loss_mask
# ------------------------------------------------------------------
# __getitem__
# ------------------------------------------------------------------
def _augment(self, arr: np.ndarray) -> np.ndarray:
if np.random.rand() > 0.5:
k = np.random.randint(1, 4)
arr = np.rot90(arr, k).copy()
if np.random.rand() > 0.5:
arr = np.fliplr(arr).copy()
if np.random.rand() > 0.5:
arr = np.flipud(arr).copy()
return arr
def __getitem__(self, idx):
item = self.data[idx]
raw_np = self._augment(np.array(item['map'], dtype=np.int64)) # [H, W]
raw_flat = raw_np.reshape(-1) # [H*W]
subset = self._choose_subset()
if self.stage == 1:
stage_input_np, target_np, loss_mask_np = self._make_stage1(raw_flat, subset)
elif self.stage == 2:
stage_input_np, target_np, loss_mask_np = self._make_stage2(raw_flat, subset)
else:
stage_input_np, target_np, loss_mask_np = self._make_stage3(raw_flat, subset)
# 若 loss_mask 全为 False如地图中无 resource 时的 stage3
# 退回为全图损失,避免 NaN
if not loss_mask_np.any():
loss_mask_np = np.ones_like(loss_mask_np)
# VQ 切片:当前阶段编码器的输入(仅保留相关 tile
raw_t = torch.LongTensor(raw_flat)
vq_keep = self._VQ_KEEP[self.stage]
if vq_keep is None:
vq_slice = raw_t.clone()
else:
vq_slice = make_slice(raw_t, vq_keep)
# 结构标签
sym_h, sym_v, sym_c = _compute_symmetry(raw_np)
cond_sym = sym_h * 4 + sym_v * 2 + sym_c
cond_room = item['roomCountLevel']
cond_branch = item['branchLevel']
cond_outer = item['outerWall']
struct_cond = torch.LongTensor([cond_sym, cond_room, cond_branch, cond_outer])
return {
"raw_map": raw_t,
"vq_slice": vq_slice,
"stage_input": torch.LongTensor(stage_input_np),
"target_map": torch.LongTensor(target_np),
"loss_mask": torch.BoolTensor(loss_mask_np),
"subset": subset,
"struct_cond": struct_cond,
}
if __name__ == "__main__":
import os
data_path = os.path.join(os.path.dirname(__file__), '..', 'data', 'ginka-dataset.json')
ds = GinkaVQDataset(data_path)
print(f"数据集大小: {len(ds)}")
subset_count = {'A': 0, 'B': 0, 'C': 0, 'D': 0}
for i in range(200):
sample = ds[i % len(ds)]
subset_count[sample['subset']] += 1
raw = sample['raw_map']
masked = sample['masked_map']
target = sample['target_map']
print(f"raw_map shape={raw.shape}, dtype={raw.dtype}")
print(f"masked_map shape={masked.shape}, dtype={masked.dtype}")
print(f"target_map shape={target.shape}, dtype={target.dtype}")
print(f"被 mask 的位置数: {(masked == 6).sum().item()} / {masked.numel()}")
print(f"\n200 次采样子集分布: {subset_count}")