diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md deleted file mode 100644 index 82fdf57..0000000 --- a/.github/copilot-instructions.md +++ /dev/null @@ -1,30 +0,0 @@ -# Ginka 地图生成器 - Copilot 指引 - -## 项目概述 - -本项目是一个基于深度学习的二维网格状地图生成模型,用于生成魔塔(Magic Tower)类网页游戏地图。 - -- **模型结构**:VQ-VAE 风格编码器 + MaskGIT 解码器 - - VQ-VAE 编码器将完整地图压缩为离散隐变量 z(从 codebook 查得) - - MaskGIT 以 z 为条件,通过迭代掩码预测生成地图 - - 推理时直接随机采样 z,无需用户输入 -- **地图规格**:13×13 格子,7 类图块 -- **目录结构** - - `ginka/` — 模型定义与训练脚本(Python) - - `data/` — 数据预处理(TypeScript,因游戏是网页游戏) - - `docs/` — 设计文档 - - `shared/` — 可视化等共享工具 - -## 重要约束 - -### 训练 -- **不要在当前设备上运行训练**,训练在其他设备上进行 -- 可以运行小规模验证、推理或单步测试,但不要触发完整训练流程 - -### 代码风格 -- **Python**:不使用三引号注释(`"""..."""`),一律改用 `#` 注释;不出现连续空格;遵循 Prettier 风格(缩进 4 空格,行宽 88) -- **TypeScript**:遵循 Prettier 默认风格 - -### 验证与可视化 -- 编写验证代码时,优先输出可视化结果(图片文件),使用 `shared/image.py` 中的工具 -- 验证阶段应对不同条件(不同 z 采样)分别生成图片,便于直观对比模型效果 diff --git a/ginka/dataset.py b/ginka/dataset.py index 6d3b834..194a1af 100644 --- a/ginka/dataset.py +++ b/ginka/dataset.py @@ -4,715 +4,187 @@ import torch import numpy as np from torch.utils.data import Dataset -def _compute_map_labels(map_2d) -> dict: - """ - 从 2D 地图列表(或 numpy 数组)推算结构标签。 - 当 JSON 数据缺少 roomCount / highDegBranchCount / outerWall 字段时调用。 - """ - arr = np.array(map_2d, dtype=np.int64) # [H, W] - H, W = arr.shape - WALL, ENTRY = 1, 5 - - # outerWall:最外圈中 wall+entry 占比 > 90% - border = np.concatenate([arr[0, :], arr[-1, :], arr[1:-1, 0], arr[1:-1, -1]]) - total_b = border.size - outer_wall = int(total_b > 0 and np.sum((border == WALL) | (border == ENTRY)) / total_b > 0.9) - - # roomCount:BFS 统计 floor(0)+resource(3) 连通区域, - # 需满足:总格子 >= 4,外接矩形宽 >= 2 且高 >= 2 - FLOOR_SET = (0, 3) - visited = np.zeros((H, W), dtype=bool) - room_count = 0 - for sy in range(H): - for sx in range(W): - if arr[sy, sx] not in FLOOR_SET or visited[sy, sx]: - continue - queue = [(sy, sx)] - visited[sy, sx] = True - tiles_y, tiles_x = [sy], [sx] - head = 0 - while head < len(queue): - y, x = queue[head]; head += 1 - for dy, dx in ((-1, 0), (1, 0), (0, -1), (0, 1)): - ny, nx = y + dy, x + dx - if 0 <= ny < H and 0 <= nx < W and not visited[ny, nx] and arr[ny, nx] in FLOOR_SET: - visited[ny, nx] = True - queue.append((ny, nx)) - tiles_y.append(ny); tiles_x.append(nx) - if (len(tiles_y) >= 4 - and max(tiles_y) - min(tiles_y) >= 1 - and max(tiles_x) - min(tiles_x) >= 1): - room_count += 1 - - # highDegBranchCount:非 wall 格子中,4 邻域非 wall 邻居 >= 3 的数量 - non_wall = (arr != WALL).astype(np.int32) - padded = np.pad(non_wall, 1, mode='constant', constant_values=0) - nbr_sum = (padded[:-2, 1:-1] + padded[2:, 1:-1] + - padded[1:-1, :-2] + padded[1:-1, 2:]) - high_deg = int(np.sum((non_wall == 1) & (nbr_sum >= 3))) - - return {'outerWall': outer_wall, 'roomCount': room_count, 'highDegBranchCount': high_deg} - - def load_data(path: str): with open(path, 'r', encoding="utf-8") as f: data = json.load(f) data_list = [] for value in data["data"].values(): - # 兼容旧版数据集(缺少结构标签字段) - if 'roomCount' not in value: - labels = _compute_map_labels(value['map']) - value.update(labels) - # symmetry 字段由 __getitem__ 在增强后重新计算,此处不需要从 JSON 读取 data_list.append(value) return data_list -def _compute_symmetry(target_np: np.ndarray) -> tuple: +def compute_symmetry(target_np: np.ndarray) -> tuple: """从 numpy 地图矩阵中直接计算三种对称性,O(H*W)""" sym_h = bool(np.all(target_np == target_np[:, ::-1])) sym_v = bool(np.all(target_np == target_np[::-1, :])) sym_c = bool(np.all(target_np == target_np[::-1, ::-1])) return int(sym_h), int(sym_v), int(sym_c) - -class GinkaVQDataset(Dataset): - """ - 用于 VQ-VAE + MaskGIT 联合训练的多子集数据集。 - - 每次 __getitem__ 按权重随机选取以下四种子集之一: - A (standard): 标准 MaskGIT 随机掩码,随机遮盖部分 tile - B (wall-only): 仅保留 wall(1) + floor(0),其余全部替换为 MASK(6) - C (wall-random): 在 B 基础上,再随机 mask 部分 wall tile - D (wall+entry): 仅保留 wall(1) + floor(0) + entrance(5),其余全部替换为 MASK(6) - - 返回 dict: - raw_map: LongTensor [H*W] 完整原始地图(供 VQ-VAE 编码) - masked_map: LongTensor [H*W] MaskGIT 输入(被 mask 的位置 = 6) - target_map: LongTensor [H*W] CE loss ground truth(等同 raw_map) - subset: str 子集标识,供调试/统计用 - """ - - FLOOR = 0 - WALL = 1 +class GinkaSeperatedDataset(Dataset): + FLOOR = 0 + WALL = 1 + DOOR = 2 + RESOURCE = 3 + MONSTER = 4 ENTRANCE = 5 - MASK_ID = 6 + MASK_ID = 6 def __init__( self, data_path: str, - subset_weights: tuple = (0.5, 0.2, 0.2, 0.1), - wall_mask_ratio: float = 0.3, - room_thresholds: tuple = None, - branch_thresholds: tuple = None, + subset_weights: tuple = (0.5, 0.3, 0.2), + subset2_wall_prob: float = 0.7 ): - """ - Args: - data_path: JSON 数据文件路径 - subset_weights: 子集 (A, B, C, D) 的采样权重,自动归一化 - wall_mask_ratio: Subset C 中额外随机 mask 的 wall tile 比例上限 - (每次从 [0, wall_mask_ratio] 均匀采样实际比例) - room_thresholds: (th1, th2) 房间数量等频分箱阈值;为 None 时自动从当前数据计算(训练集) - branch_thresholds: (th1, th2) 分支数量等频分箱阈值;为 None 时自动从当前数据计算(训练集) - """ self.data = load_data(data_path) - self.wall_mask_ratio = wall_mask_ratio + self.subset2_wall_prob = subset2_wall_prob + total = sum(subset_weights) + self.subset_cumw = [sum(subset_weights[:i+1]) / total for i in range(len(subset_weights))] - # 累积权重,用于快速随机子集选择 - total_w = sum(subset_weights) - normalized = [x / total_w for x in subset_weights] - self.subset_cumw = [sum(normalized[:i + 1]) for i in range(len(normalized))] + n = len(self.data) + rs = sorted(item['roomCount'] for item in self.data) + bs = sorted(item['highDegBranchCount'] for item in self.data) + th1_r, th2_r = rs[n // 3], rs[2 * n // 3] + th1_b, th2_b = bs[n // 3], bs[2 * n // 3] + if th1_r == th2_r: th2_r = th1_r + 1 + if th1_b == th2_b: th2_b = th1_b + 1 + self.room_th = (th1_r, th2_r) + self.branch_th = (th1_b, th2_b) - # ── 两趟扫描:计算等频分箱阈值 ────────────────────────────── - room_counts = [item['roomCount'] for item in self.data] - branch_counts = [item['highDegBranchCount'] for item in self.data] - - if room_thresholds is None: - n = len(room_counts) - rs = sorted(room_counts) - bs = sorted(branch_counts) - th1_r, th2_r = rs[n // 3], rs[2 * n // 3] - th1_b, th2_b = bs[n // 3], bs[2 * n // 3] - # 防止 Medium 等级为空 - if th1_r == th2_r: - th2_r = th1_r + 1 - if th1_b == th2_b: - th2_b = th1_b + 1 - self.room_th = (th1_r, th2_r) - self.branch_th = (th1_b, th2_b) - else: - self.room_th = room_thresholds - self.branch_th = branch_thresholds - - def to_level(v: int, th: tuple) -> int: - return 0 if v < th[0] else (1 if v < th[1] else 2) - - # 回填等级字段 for item in self.data: - item['roomCountLevel'] = to_level(item['roomCount'], self.room_th) - item['branchLevel'] = to_level(item['highDegBranchCount'], self.branch_th) + item['roomCountLevel'] = self.to_level(item['roomCount'], self.room_th) + item['branchLevel'] = self.to_level(item['highDegBranchCount'], self.branch_th) + + def to_level(self, v, th): + return 0 if v < th[0] else (1 if v < th[1] else 2) def __len__(self): return len(self.data) - # ------------------------------------------------------------------ - # 内联随机掩码生成(避免 scipy 的 NumPy 版本兼容问题) - # ------------------------------------------------------------------ - @staticmethod - def _sample_mask_ratio(min_r=0.05, max_r=1.0) -> float: - """用 Beta(2,2) 分布采样掩码比例,集中在中间值。""" - r = np.random.beta(2, 2) - return min_r + (max_r - min_r) * r + def degrade_tile(self, m: np.ndarray, tiles: list) -> np.ndarray: + # 将指定 tile ID 替换为 floor(0),原地修改 + for t in tiles: + m[m == t] = self.FLOOR + return m - @staticmethod - def _random_mask(h: int, w: int) -> np.ndarray: - """纯随机掩码,返回 [H*W] bool。""" - ratio = GinkaVQDataset._sample_mask_ratio() - total = h * w - idx = np.random.choice(total, int(total * ratio), replace=False) - mask = np.zeros(total, dtype=bool) - mask[idx] = True - return mask - - @staticmethod - def _block_mask(h: int, w: int) -> np.ndarray: - """矩形分块随机掩码,返回 [H*W] bool。""" - ratio = GinkaVQDataset._sample_mask_ratio() - max_block = max(2, min(h, w) // 2) - target = int(h * w * ratio) - mask = np.zeros((h, w), dtype=bool) - while mask.sum() < target: - bh = np.random.randint(2, max_block + 1) - bw = np.random.randint(2, max_block + 1) - x = np.random.randint(0, max(1, h - bh + 1)) - y = np.random.randint(0, max(1, w - bw + 1)) - mask[x:x + bh, y:y + bw] = True - return mask.reshape(-1) - - def _std_mask(self, h: int, w: int) -> np.ndarray: - """标准 MaskGIT 掩码:随机选择纯随机或分块策略。""" + def std_mask(self) -> np.ndarray: + # Beta(2,2) 采样掩码比例,50% 随机掩码 / 50% 分块掩码,返回 bool[13, 13] + ratio = float(np.random.beta(2, 2)) * 0.95 + 0.05 if random.random() < 0.5: - return self._random_mask(h, w) - else: - return self._block_mask(h, w) + idx = np.random.choice(169, int(169 * ratio), replace=False) + mask = np.zeros(169, dtype=bool) + mask[idx] = True + return mask.reshape(13, 13) + target = int(169 * ratio) + mask = np.zeros((13, 13), dtype=bool) + while mask.sum() < target: + bh = np.random.randint(2, 7) + bw = np.random.randint(2, 7) + x = np.random.randint(0, 14 - bh) + y = np.random.randint(0, 14 - bw) + mask[x:x + bh, y:y + bw] = True + return mask + + def create_degreaded(self, raw: np.ndarray): + # 阶段一:生成墙壁和入口 + target1 = raw.copy() + self.degrade_tile(target1, [self.DOOR, self.RESOURCE, self.MONSTER]) + inp1 = target1.copy() + + # 阶段二:生成怪物、门,同时也允许生成入口以适配结构 + target2 = raw.copy() + self.degrade_tile(target2, [self.RESOURCE, self.WALL]) + inp2 = raw.copy() + self.degrade_tile(inp2, [self.RESOURCE]) + + # 阶段三:生成资源 + target3 = raw.copy() + self.degrade_tile(target3, [self.WALL, self.DOOR, self.MONSTER, self.ENTRANCE]) + inp3 = raw.copy() + + return target1, inp1, target2, inp2, target3, inp3 - # ------------------------------------------------------------------ + def apply_subset1(self, raw: np.ndarray): + # 子集 1:std_mask 随机掩码 + + target1, inp1, target2, inp2, target3, inp3 = self.create_degreaded(raw) + + enc1 = target1.copy() + enc2 = inp2.copy() + enc3 = raw.copy() + + # stage1:对整图 std_mask + inp1[self.std_mask()] = self.MASK_ID + + # stage2:对 floor+功能元素区域 std_mask + need_mask = np.isin(inp2, [self.FLOOR, self.DOOR, self.MONSTER, self.ENTRANCE]) + inp2[need_mask & self.std_mask()] = self.MASK_ID + + # stage3:对 floor+resource 区域 std_mask + need_mask = np.isin(inp3, [self.FLOOR, self.RESOURCE]) + inp3[need_mask & self.std_mask()] = self.MASK_ID + + return inp1, target1, enc1, inp2, target2, enc2, inp3, target3, enc3 + + def apply_subset2(self, raw: np.ndarray): + # 子集 2:掩码所有内容,墙壁随机掩码,不掩码入口 + target1, inp1, target2, inp2, target3, inp3 = self.create_degreaded(raw) + + enc1 = target1.copy() + enc2 = inp2.copy() + enc3 = raw.copy() + + if np.random.random() < self.subset2_wall_prob: + inp1[self.std_mask()] = self.MASK_ID + need_mask = np.isin(inp2, [self.FLOOR, self.DOOR, self.MONSTER, self.ENTRANCE]) + inp2[need_mask] = self.MASK_ID + need_mask = np.isin(inp3, [self.FLOOR, self.RESOURCE]) + inp3[need_mask] = self.MASK_ID + + return inp1, target1, enc1, inp2, target2, enc2, inp3, target3, enc3 + + def apply_subset3(self, raw: np.ndarray): + # 子集 3:在 2 的基础上掩码入口 + out = self.apply_subset2(raw) + out[0][out[0] == self.ENTRANCE] = self.MASK_ID + return out + + def __getitem__(self, idx): + item = self.data[idx] + map_np = np.array(item['map'], dtype=np.int64) - def _augment(self, arr: np.ndarray) -> np.ndarray: - """随机旋转 / 翻转数据增强,返回新 array。""" if np.random.rand() > 0.5: k = np.random.randint(1, 4) - arr = np.rot90(arr, k).copy() + map_np = np.rot90(map_np, k).copy() if np.random.rand() > 0.5: - arr = np.fliplr(arr).copy() + map_np = np.fliplr(map_np).copy() if np.random.rand() > 0.5: - arr = np.flipud(arr).copy() - return arr + map_np = np.flipud(map_np).copy() - def _choose_subset(self) -> str: r = random.random() if r < self.subset_cumw[0]: - return 'A' + out = self.apply_subset1(map_np) elif r < self.subset_cumw[1]: - return 'B' - elif r < self.subset_cumw[2]: - return 'C' + out = self.apply_subset2(map_np) else: - return 'D' + out = self.apply_subset3(map_np) - def _apply_subset(self, raw: np.ndarray, subset: str) -> np.ndarray: - """ - 根据子集类型生成 masked_map。 - - Args: - raw: [H, W] int64 完整原始地图 - subset: 'A' | 'B' | 'C' | 'D' - - Returns: - [H*W] int64,被遮盖位置値为 MASK_ID(6) - """ - H, W = raw.shape - - if subset == 'A': - # 标准随机 mask:纯随机或分块策略 - mask = self._std_mask(H, W) # [H*W] bool - flat = raw.reshape(-1).copy() - flat[mask] = self.MASK_ID - return flat - - elif subset == 'B': - # 仅保留 wall(1),floor(0) 和其他非墙内容全部 mask - flat = raw.reshape(-1).copy() - keep = (flat == self.WALL) - flat[~keep] = self.MASK_ID - return flat - - elif subset == 'C': - # Subset B + 随机 mask 部分 wall - flat = raw.reshape(-1).copy() - keep = (flat == self.WALL) - flat[~keep] = self.MASK_ID - - wall_idx = np.where(flat == self.WALL)[0] - if len(wall_idx) > 0: - ratio = random.random() * self.wall_mask_ratio - n = max(1, int(len(wall_idx) * ratio)) - chosen = np.random.choice(wall_idx, n, replace=False) - flat[chosen] = self.MASK_ID - return flat - - else: # D - # 仅保留 wall(1) 和 entrance(5),floor(0) 和其他非墙内容全部 mask - flat = raw.reshape(-1).copy() - keep = (flat == self.WALL) | (flat == self.ENTRANCE) - flat[~keep] = self.MASK_ID - - # 随机 mask 部分 wall(模拟真实场景,与子集 C 一致) - wall_idx = np.where(flat == self.WALL)[0] - if len(wall_idx) > 0: - ratio = random.random() * self.wall_mask_ratio - n = max(1, int(len(wall_idx) * ratio)) - chosen = np.random.choice(wall_idx, n, replace=False) - flat[chosen] = self.MASK_ID - return flat - - def __getitem__(self, idx): - item = self.data[idx] - - raw_np = self._augment(np.array(item['map'], dtype=np.int64)) # [H, W] - subset = self._choose_subset() - masked_np = self._apply_subset(raw_np, subset) # [H*W] - raw_flat = raw_np.reshape(-1) # [H*W] - - # 对称性:在增强后重新计算 - sym_h, sym_v, sym_c = _compute_symmetry(raw_np) - cond_sym = sym_h * 4 + sym_v * 2 + sym_c # [0, 7] - - # 其余结构标签:增强不改变拓扑结构,直接读取 - cond_room = item['roomCountLevel'] # 0/1/2 - cond_branch = item['branchLevel'] # 0/1/2 - cond_outer = item['outerWall'] # 0/1 - - struct_cond = torch.LongTensor([cond_sym, cond_room, cond_branch, cond_outer]) - - raw_t = torch.LongTensor(raw_flat) - return { - "raw_map": raw_t, # VQ-VAE 编码器输入 - "slice1": make_slice(raw_t, {0, 1}), # 通道 1:floor+wall - "slice2": make_slice(raw_t, {0, 1, 2, 4, 5}), # 通道 2:floor+wall+门+怪+入口 - "slice3": raw_t.clone(), # 通道 3:完整地图 - "masked_map": torch.LongTensor(masked_np), # MaskGIT 输入 - "target_map": torch.LongTensor(raw_flat.copy()), # CE loss ground truth - "subset": subset, # 调试/统计用 - "struct_cond": struct_cond, # [4],供模型 Embedding 查表 - } - - -# --------------------------------------------------------------------------- -# make_slice:按保留集合切割地图,其余位置替换为 floor(0) -# --------------------------------------------------------------------------- - -def make_slice(map_flat: torch.Tensor, keep_set: set) -> torch.Tensor: - """ - 从完整地图中只保留 keep_set 中的 tile 类型,其余位置替换为 floor(0)。 - - Args: - map_flat: LongTensor [H*W] 完整地图 tile ID 序列 - keep_set: set of int 需要保留的 tile 类型集合 - - Returns: - LongTensor [H*W] 切片后的地图(非保留 tile 位置值为 0) - """ - out = map_flat.clone() - mask = torch.zeros_like(out, dtype=torch.bool) - for t in keep_set: - mask |= (out == t) - out[~mask] = 0 - return out - - -# --------------------------------------------------------------------------- -# GinkaSplitDataset:三通道分拆预训练专用数据集 -# --------------------------------------------------------------------------- - -class GinkaSplitDataset(Dataset): - """ - 三通道分拆预训练(方案 B)专用数据集。 - - 每个样本只提供完整地图及其三路切片,不做 MaskGIT 掩码处理。 - 切片按累积式设计: - slice1 = floor(0) + wall(1) - slice2 = floor(0) + wall(1) + door(2) + mob(4) + entrance(5) - slice3 = 完整地图(所有 tile) - - 返回 dict: - raw_map: LongTensor [H*W] 完整原始地图 - slice1: LongTensor [H*W] 通道 1 切片(floor+wall) - slice2: LongTensor [H*W] 通道 2 切片(floor+wall+门+怪+入口) - slice3: LongTensor [H*W] 通道 3 切片(完整地图) - """ - - def __init__(self, data_path: str): - self.data = load_data(data_path) - - def __len__(self): - return len(self.data) - - def __getitem__(self, idx): - item = self.data[idx] - arr = np.array(item['map'], dtype=np.int64) # [H, W] - - # 随机旋转 / 翻转数据增强 - if np.random.rand() > 0.5: - k = np.random.randint(1, 4) - arr = np.rot90(arr, k).copy() - if np.random.rand() > 0.5: - arr = np.fliplr(arr).copy() - if np.random.rand() > 0.5: - arr = np.flipud(arr).copy() - - raw = torch.LongTensor(arr.reshape(-1)) # [H*W] - return { - "raw_map": raw, - "slice1": make_slice(raw, {0, 1}), - "slice2": make_slice(raw, {0, 1, 2, 4, 5}), - "slice3": raw.clone(), - } - - -# --------------------------------------------------------------------------- -# GinkaStageDataset:三阶段级联训练专用数据集 -# --------------------------------------------------------------------------- - -class GinkaStageDataset(Dataset): - """ - 三阶段级联生成训练专用 Dataset。 - - 每个阶段只预测特定类别的 tile,后续阶段以前序阶段输出作为上下文。 - 训练时统一使用 GT 作为前序上下文(teacher forcing),避免误差级联。 - - 阶段划分: - stage=1 结构骨架:预测 floor(0) + wall(1) - stage=2 功能元素:预测 door(2) + monster(4) + entrance(5),以 floor/wall 为上下文 - stage=3 资源放置:预测 resource(3),以完整骨架为上下文 - - 返回 dict: - raw_map: LongTensor [H*W] 完整原始地图(供 VQ-VAE 编码) - vq_slice: LongTensor [H*W] 当前阶段 VQ 编码器的输入切片 - stage_input: LongTensor [H*W] MaskGIT 输入(含上下文 + MASK 位置) - target_map: LongTensor [H*W] CE loss ground truth - loss_mask: BoolTensor [H*W] 只对 True 位置计算损失 - subset: str 子集标识 A/B/C/D - struct_cond: LongTensor [4] [sym, room, branch, outer] - """ - - FLOOR = 0 - WALL = 1 - DOOR = 2 - RESOURCE = 3 - MONSTER = 4 - ENTRANCE = 5 - MASK_ID = 6 - - STAGE1_TARGETS = frozenset({0, 1}) - STAGE2_TARGETS = frozenset({2, 4, 5}) - STAGE3_TARGETS = frozenset({3}) - - # VQ 切片集合:各阶段编码器只"看"与自身相关的 tile - _VQ_KEEP = { - 1: frozenset({0, 1}), - 2: frozenset({0, 1, 2, 4, 5}), - 3: None, # 完整地图 - } - - def __init__( - self, - data_path: str, - stage: int, - subset_weights: tuple = (0.5, 0.2, 0.2, 0.1), - wall_mask_ratio: float = 0.3, - room_thresholds: tuple = None, - branch_thresholds: tuple = None, - ): - """ - Args: - data_path: JSON 数据文件路径 - stage: 生成阶段 1/2/3 - subset_weights: 子集 (A, B, C, D) 的采样权重,自动归一化 - wall_mask_ratio: Subset C 中额外随机 mask 的 wall 比例上限 - room_thresholds: 等频分箱阈值(None 时自动计算) - branch_thresholds: 等频分箱阈值(None 时自动计算) - """ - assert stage in (1, 2, 3), f"stage 必须是 1/2/3,收到 {stage}" - self.stage = stage - self.data = load_data(data_path) - self.wall_mask_ratio = wall_mask_ratio - - total_w = sum(subset_weights) - normalized = [x / total_w for x in subset_weights] - self.subset_cumw = [sum(normalized[:i + 1]) for i in range(len(normalized))] - - room_counts = [item['roomCount'] for item in self.data] - branch_counts = [item['highDegBranchCount'] for item in self.data] - - if room_thresholds is None: - n = len(room_counts) - rs = sorted(room_counts) - bs = sorted(branch_counts) - th1_r, th2_r = rs[n // 3], rs[2 * n // 3] - th1_b, th2_b = bs[n // 3], bs[2 * n // 3] - if th1_r == th2_r: th2_r = th1_r + 1 - if th1_b == th2_b: th2_b = th1_b + 1 - self.room_th = (th1_r, th2_r) - self.branch_th = (th1_b, th2_b) - else: - self.room_th = room_thresholds - self.branch_th = branch_thresholds - - def to_level(v, th): - return 0 if v < th[0] else (1 if v < th[1] else 2) - - for item in self.data: - item['roomCountLevel'] = to_level(item['roomCount'], self.room_th) - item['branchLevel'] = to_level(item['highDegBranchCount'], self.branch_th) - - def __len__(self): - return len(self.data) - - # ------------------------------------------------------------------ - # 掩码辅助(与 GinkaVQDataset 相同逻辑) - # ------------------------------------------------------------------ - @staticmethod - def _sample_mask_ratio(min_r=0.05, max_r=1.0) -> float: - r = np.random.beta(2, 2) - return min_r + (max_r - min_r) * r - - @staticmethod - def _random_mask(h: int, w: int) -> np.ndarray: - ratio = GinkaStageDataset._sample_mask_ratio() - total = h * w - idx = np.random.choice(total, int(total * ratio), replace=False) - mask = np.zeros(total, dtype=bool) - mask[idx] = True - return mask - - @staticmethod - def _block_mask(h: int, w: int) -> np.ndarray: - ratio = GinkaStageDataset._sample_mask_ratio() - max_block = max(2, min(h, w) // 2) - target = int(h * w * ratio) - mask = np.zeros((h, w), dtype=bool) - while mask.sum() < target: - bh = np.random.randint(2, max_block + 1) - bw = np.random.randint(2, max_block + 1) - x = np.random.randint(0, max(1, h - bh + 1)) - y = np.random.randint(0, max(1, w - bw + 1)) - mask[x:x + bh, y:y + bw] = True - return mask.reshape(-1) - - def _std_mask(self, h: int, w: int) -> np.ndarray: - return self._random_mask(h, w) if random.random() < 0.5 else self._block_mask(h, w) - - # ------------------------------------------------------------------ - # 子集选择 - # ------------------------------------------------------------------ - def _choose_subset(self) -> str: - r = random.random() - if r < self.subset_cumw[0]: return 'A' - if r < self.subset_cumw[1]: return 'B' - if r < self.subset_cumw[2]: return 'C' - return 'D' - - # ------------------------------------------------------------------ - # 阶段一:结构骨架(floor + wall) - # ------------------------------------------------------------------ - def _make_stage1(self, raw_flat: np.ndarray, subset: str): - """ - 阶段一:预测 floor/wall,所有非 floor/wall tile 在目标中重映射为 floor。 - 子集决定向模型提供多少 wall 作为上下文条件。 - """ - H = W = 13 - - # 目标:非 floor/wall → floor - target = raw_flat.copy() - target[~np.isin(target, [self.FLOOR, self.WALL])] = self.FLOOR - - inp = target.copy() - - if subset == 'A': - # 标准随机 mask:随机遮盖部分 floor/wall - mask = self._std_mask(H, W) - inp[mask] = self.MASK_ID - - elif subset == 'B': - # 保留全部 wall,MASK floor - inp[inp == self.FLOOR] = self.MASK_ID - - elif subset == 'C': - # 随机保留部分 wall,MASK 其余(含全部 floor) - inp[inp == self.FLOOR] = self.MASK_ID - wall_idx = np.where(inp == self.WALL)[0] - if len(wall_idx) > 0: - ratio = random.random() * self.wall_mask_ratio - n = max(1, int(len(wall_idx) * ratio)) - chosen = np.random.choice(wall_idx, n, replace=False) - inp[chosen] = self.MASK_ID - - else: # D:与 B 相同(阶段一无 entrance 维度) - inp[inp == self.FLOOR] = self.MASK_ID - - loss_mask = (inp == self.MASK_ID) - return inp, target, loss_mask - - # ------------------------------------------------------------------ - # 阶段二:功能元素(door + monster + entrance) - # ------------------------------------------------------------------ - def _make_stage2(self, raw_flat: np.ndarray, subset: str): - """ - 阶段二:以 floor/wall 为上下文,预测 door/monster/entrance。 - resource 在输入与目标中均视为 floor(阶段二不负责资源)。 - 子集决定 wall 上下文的完整程度与 door/monster/entrance 的掩码方式。 - """ - # 目标:resource → floor - target = raw_flat.copy() - target[target == self.RESOURCE] = self.FLOOR - - # 基础输入:resource → floor,功能元素先保留,再按子集处理 - inp = raw_flat.copy() - inp[inp == self.RESOURCE] = self.FLOOR - - if subset == 'A': - # 随机遮盖部分 door/monster/entrance(部分上下文补全) - func_idx = np.where(np.isin(inp, [self.DOOR, self.MONSTER, self.ENTRANCE]))[0] - if len(func_idx) > 0: - ratio = random.random() * 0.8 + 0.2 # 20%~100% - n = max(1, int(len(func_idx) * ratio)) - chosen = np.random.choice(func_idx, n, replace=False) - inp[chosen] = self.MASK_ID - else: - # B/C/D:全部 door/monster/entrance → MASK - inp[np.isin(inp, [self.DOOR, self.MONSTER, self.ENTRANCE])] = self.MASK_ID - - if subset == 'C': - # 额外随机 mask 部分 wall(降低 wall 上下文质量) - wall_idx = np.where(inp == self.WALL)[0] - if len(wall_idx) > 0: - ratio = random.random() * self.wall_mask_ratio - n = max(1, int(len(wall_idx) * ratio)) - chosen = np.random.choice(wall_idx, n, replace=False) - inp[chosen] = self.MASK_ID - - # loss_mask:阶段二只对 door/monster/entrance 原始位置计算损失, - # 不对被额外 mask 的 wall 位置计算(它们在 target 中已知为 wall) - loss_mask = np.isin(raw_flat, [self.DOOR, self.MONSTER, self.ENTRANCE]) - return inp, target, loss_mask - - # ------------------------------------------------------------------ - # 阶段三:资源放置(resource) - # ------------------------------------------------------------------ - def _make_stage3(self, raw_flat: np.ndarray, subset: str): - """ - 阶段三:以完整骨架为上下文,预测 resource 位置。 - 所有 resource 位置在输入中替换为 MASK。 - 子集 A 随机保留部分 resource 作为上下文(部分补全训练), - 其余子集始终 MASK 全部 resource。 - """ - target = raw_flat.copy() - inp = raw_flat.copy() - - if subset == 'A': - # 随机遮盖部分 resource(部分上下文补全) - res_idx = np.where(inp == self.RESOURCE)[0] - if len(res_idx) > 0: - ratio = random.random() * 0.8 + 0.2 # 20%~100% - n = max(1, int(len(res_idx) * ratio)) - chosen = np.random.choice(res_idx, n, replace=False) - inp[chosen] = self.MASK_ID - else: - pass # 无 resource 时无需处理 - else: - # B/C/D:全部 resource → MASK - inp[inp == self.RESOURCE] = self.MASK_ID - - loss_mask = (inp == self.MASK_ID) - return inp, target, loss_mask - - # ------------------------------------------------------------------ - # __getitem__ - # ------------------------------------------------------------------ - def _augment(self, arr: np.ndarray) -> np.ndarray: - if np.random.rand() > 0.5: - k = np.random.randint(1, 4) - arr = np.rot90(arr, k).copy() - if np.random.rand() > 0.5: - arr = np.fliplr(arr).copy() - if np.random.rand() > 0.5: - arr = np.flipud(arr).copy() - return arr - - def __getitem__(self, idx): - item = self.data[idx] - - raw_np = self._augment(np.array(item['map'], dtype=np.int64)) # [H, W] - raw_flat = raw_np.reshape(-1) # [H*W] - subset = self._choose_subset() - - if self.stage == 1: - stage_input_np, target_np, loss_mask_np = self._make_stage1(raw_flat, subset) - elif self.stage == 2: - stage_input_np, target_np, loss_mask_np = self._make_stage2(raw_flat, subset) - else: - stage_input_np, target_np, loss_mask_np = self._make_stage3(raw_flat, subset) - - # 若 loss_mask 全为 False(如地图中无 resource 时的 stage3), - # 退回为全图损失,避免 NaN - if not loss_mask_np.any(): - loss_mask_np = np.ones_like(loss_mask_np) - - # VQ 切片:当前阶段编码器的输入(仅保留相关 tile) - raw_t = torch.LongTensor(raw_flat) - vq_keep = self._VQ_KEEP[self.stage] - if vq_keep is None: - vq_slice = raw_t.clone() - else: - vq_slice = make_slice(raw_t, vq_keep) - - # 结构标签 - sym_h, sym_v, sym_c = _compute_symmetry(raw_np) - cond_sym = sym_h * 4 + sym_v * 2 + sym_c - cond_room = item['roomCountLevel'] + sym_h, sym_v, sym_c = compute_symmetry(map_np) + cond_sym = sym_h * 4 + sym_v * 2 + sym_c + cond_room = item['roomCountLevel'] cond_branch = item['branchLevel'] - cond_outer = item['outerWall'] - struct_cond = torch.LongTensor([cond_sym, cond_room, cond_branch, cond_outer]) + cond_outer = item['outerWall'] + struct_inject = torch.LongTensor([cond_sym, cond_room, cond_branch, cond_outer]) return { - "raw_map": raw_t, - "vq_slice": vq_slice, - "stage_input": torch.LongTensor(stage_input_np), - "target_map": torch.LongTensor(target_np), - "loss_mask": torch.BoolTensor(loss_mask_np), - "subset": subset, - "struct_cond": struct_cond, + "input_stage1": torch.LongTensor(out[0]), + "target_stage1": torch.LongTensor(out[1]), + "encoder_stage1": torch.LongTensor(out[2]), + "input_stage2": torch.LongTensor(out[3]), + "target_stage2": torch.LongTensor(out[4]), + "encoder_stage2": torch.LongTensor(out[5]), + "input_stage3": torch.LongTensor(out[6]), + "target_stage3": torch.LongTensor(out[7]), + "encoder_stage3": torch.LongTensor(out[8]), + "struct_inject": struct_inject } - - -if __name__ == "__main__": - import os - data_path = os.path.join(os.path.dirname(__file__), '..', 'data', 'ginka-dataset.json') - ds = GinkaVQDataset(data_path) - print(f"数据集大小: {len(ds)}") - - subset_count = {'A': 0, 'B': 0, 'C': 0, 'D': 0} - for i in range(200): - sample = ds[i % len(ds)] - subset_count[sample['subset']] += 1 - - raw = sample['raw_map'] - masked = sample['masked_map'] - target = sample['target_map'] - print(f"raw_map shape={raw.shape}, dtype={raw.dtype}") - print(f"masked_map shape={masked.shape}, dtype={masked.dtype}") - print(f"target_map shape={target.shape}, dtype={target.dtype}") - print(f"被 mask 的位置数: {(masked == 6).sum().item()} / {masked.numel()}") - print(f"\n200 次采样子集分布: {subset_count}") diff --git a/ginka/maskGIT/model.py b/ginka/maskGIT/model.py index 83c9aba..d5244f3 100644 --- a/ginka/maskGIT/model.py +++ b/ginka/maskGIT/model.py @@ -4,53 +4,28 @@ import torch.nn as nn from ..utils import print_memory from .maskGIT import Transformer -# 结构标签词表大小(最后一个索引为无条件占位符 null) -SYM_VOCAB = 8 # symmetryH/V/C 三位组合 0-6,7 = null -ROOM_VOCAB = 4 # roomCountLevel 0-2,3 = null -BRANCH_VOCAB = 4 # branchLevel 0-2,3 = null -OUTER_VOCAB = 3 # outerWall 0-1,2 = null - +# 结构标签词表大小 +SYM_VOCAB = 8 # symmetryH/V/C 三位组合 0-7 +ROOM_VOCAB = 3 # roomCountLevel 0-2 +BRANCH_VOCAB = 3 # branchLevel 0-2 +OUTER_VOCAB = 2 # outerWall 0-1 class GinkaMaskGIT(nn.Module): - """ - 改造后的 MaskGIT 地图生成模型。 - - 以掩码地图序列和 VQ-VAE 输出的离散隐变量 z 为输入, - 通过 Transformer encoder-decoder 结构预测被遮盖位置的 tile 类别。 - - z 通过 cross-attention 注入到 Transformer decoder, - 作为风格/多样性控制信号,而非结构重建指导。 - """ - def __init__( - self, - num_classes: int = 16, - d_model: int = 192, - d_z: int = 64, - dim_ff: int = 512, - nhead: int = 8, - num_layers: int = 4, - map_size: int = 13 * 13, - z_dropout: float = 0.1, - struct_dropout: float = 0.15, + self, num_classes: int = 16, d_model: int = 192, dim_ff: int = 512, + nhead: int = 8, num_layers: int = 4, map_size: int = 13 * 13, d_z: int = 64 ): """ Args: num_classes: tile 类别数(含 MASK token=15) d_model: Transformer 内部维度 - d_z: VQ-VAE 码字嵌入维度,需与 GinkaVQVAE.d_z 一致 dim_ff: 前馈网络隐层维度 nhead: 注意力头数 num_layers: Transformer 层数 map_size: 地图 token 总数(H * W) - z_dropout: 训练时随机替换 z 为随机码字的概率(提升鲁棒性) - struct_dropout: 训练时以此概率将结构标签替换为 null(无条件占位), - 实现 classifier-free guidance 兼容训练 """ super().__init__() - self.z_dropout = z_dropout - self.struct_dropout_prob = struct_dropout - + # Tile 嵌入 + 位置编码 self.tile_embedding = nn.Embedding(num_classes, d_model) self.pos_embedding = nn.Parameter(torch.randn(1, map_size, d_model) * 0.02) @@ -67,10 +42,10 @@ class GinkaMaskGIT(nn.Module): # 结构标签嵌入(编码到 d_z 维度) # 注意:结构标签与 VQ 码字语义不同,使用独立投影层避免混用 - self.sym_embed = nn.Embedding(SYM_VOCAB, d_z) - self.room_embed = nn.Embedding(ROOM_VOCAB, d_z) + self.sym_embed = nn.Embedding(SYM_VOCAB, d_z) + self.room_embed = nn.Embedding(ROOM_VOCAB, d_z) self.branch_embed = nn.Embedding(BRANCH_VOCAB, d_z) - self.outer_embed = nn.Embedding(OUTER_VOCAB, d_z) + self.outer_embed = nn.Embedding(OUTER_VOCAB, d_z) self.struct_proj = nn.Sequential( nn.Linear(d_z, d_model * 2), @@ -92,108 +67,75 @@ class GinkaMaskGIT(nn.Module): self, map: torch.Tensor, z: torch.Tensor, - struct_cond: torch.Tensor | None = None, - dropout_struct: bool = False, + struct: torch.Tensor ) -> torch.Tensor: - """ - Args: - map: [B, H*W] 掩码后的地图 token 序列(MASK token = 15) - z: [B, L_total, d_z] VQ-VAE 量化后的离散隐变量; - 方案 B 中 L_total = L1+L2+L3(三路 z 拼接) - struct_cond: [B, 4] 结构标签 LongTensor,顺序为 - [cond_sym, cond_room, cond_branch, cond_outer]; - 为 None 时等价于全 null(无条件模式) - dropout_struct: bool 强制将所有结构标签替换为 null(推理时无条件生成) + # map: [B, H * W] + # z: [B, L * 3, d_z] + # struch: [B, 4] - Returns: - logits: [B, H*W, num_classes] - """ - B = z.shape[0] - - # z dropout:训练时以一定概率将 z 替换为随机均匀噪声, - # 模拟推理时随机采样 z 的分布,避免模型过拟合于精确的 z 语义 - if self.training and self.z_dropout > 0: - mask = torch.rand(B, 1, 1, device=z.device) < self.z_dropout - rand_z = torch.randn_like(z) - z = torch.where(mask, rand_z, z) - - # 结构标签嵌入 - # struct_cond 为 None 或 dropout_struct=True 时,全部使用 null 索引 - if struct_cond is None or dropout_struct: - sym_idx = torch.full((B,), SYM_VOCAB - 1, dtype=torch.long, device=z.device) - room_idx = torch.full((B,), ROOM_VOCAB - 1, dtype=torch.long, device=z.device) - branch_idx = torch.full((B,), BRANCH_VOCAB - 1, dtype=torch.long, device=z.device) - outer_idx = torch.full((B,), OUTER_VOCAB - 1, dtype=torch.long, device=z.device) - else: - sc = struct_cond.to(z.device) - sym_idx, room_idx, branch_idx, outer_idx = sc[:, 0], sc[:, 1], sc[:, 2], sc[:, 3] - - # 训练时对各标签独立做 struct dropout - if self.training and self.struct_dropout_prob > 0: - def _drop(idx, null_val): - drop_mask = torch.rand(B, device=z.device) < self.struct_dropout_prob - return torch.where(drop_mask, torch.full_like(idx, null_val), idx) - sym_idx = _drop(sym_idx, SYM_VOCAB - 1) - room_idx = _drop(room_idx, ROOM_VOCAB - 1) - branch_idx = _drop(branch_idx, BRANCH_VOCAB - 1) - outer_idx = _drop(outer_idx, OUTER_VOCAB - 1) + sym_idx = struct[:, 0] + room_idx = struct[:, 1] + branch_idx = struct[:, 2] + outer_idx = struct[:, 3] # 嵌入结构标签到 d_z 维度,拼接到 z 序列末尾 - e_sym = self.sym_embed(sym_idx).unsqueeze(1) # [B, 1, d_z] - e_room = self.room_embed(room_idx).unsqueeze(1) # [B, 1, d_z] + e_sym = self.sym_embed(sym_idx).unsqueeze(1) # [B, 1, d_z] + e_room = self.room_embed(room_idx).unsqueeze(1) # [B, 1, d_z] e_branch = self.branch_embed(branch_idx).unsqueeze(1) # [B, 1, d_z] - e_outer = self.outer_embed(outer_idx).unsqueeze(1) # [B, 1, d_z] + e_outer = self.outer_embed(outer_idx).unsqueeze(1) # [B, 1, d_z] - struct_seq = torch.cat([e_sym, e_room, e_branch, e_outer], dim=1) # [B, 4, d_z] + struct_seq = torch.cat([e_sym, e_room, e_branch, e_outer], dim=1) # [B, 4, d_z] # VQ 码字与结构标签语义不同,使用各自独立的投影层后再拼接 - z_mem_vq = self.z_proj(z) # [B, L, d_model] - z_mem_struct = self.struct_proj(struct_seq) # [B, 4, d_model] - z_mem = torch.cat([z_mem_vq, z_mem_struct], dim=1) # [B, L+4, d_model] + z_mem_vq = self.z_proj(z) # [B, L, d_model] + z_mem_struct = self.struct_proj(struct_seq) # [B, 4, d_model] + z_mem = torch.cat([z_mem_vq, z_mem_struct], dim=1) # [B, L * 3 + 4, d_model] # tile embedding + 位置编码 - x = self.tile_embedding(map) # [B, H*W, d_model] - x = x + self.pos_embedding # [B, H*W, d_model] + x = self.tile_embedding(map) # [B, H * W, d_model] + x = x + self.pos_embedding # [B, H * W, d_model] # Transformer:encoder 做 map 自注意力,decoder cross-attend z+struct - x = self.transformer(x, memory=z_mem) # [B, H*W, d_model] + x = self.transformer(x, memory=z_mem) # [B, H * W, d_model] - logits = self.output_fc(x) # [B, H*W, num_classes] + logits = self.output_fc(x) # [B, H * W, num_classes] return logits - if __name__ == "__main__": - device = torch.device("cpu") + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + map_input = torch.randint(0, 7, (4, 13 * 13)).to(device) # [4, 169] + z_input = torch.randn(4, 2, 64).to(device) # [4, 2, 64] + struct_input = torch.tensor([ + [3, 1, 0, 1], + [0, 2, 1, 0], + [5, 1, 2, 1], + [1, 0, 1, 0], + ], dtype=torch.long).to(device) # [4, 4] model = GinkaMaskGIT( - num_classes=16, + num_classes=7, d_model=192, d_z=64, - dim_ff=512, + dim_ff=2048, nhead=8, - num_layers=4, + num_layers=6, map_size=13 * 13, ).to(device) - total_params = sum(p.numel() for p in model.parameters()) - print(f"总参数量: {total_params:,} ({total_params / 1e6:.3f}M)") - for name, module in model.named_children(): - n = sum(p.numel() for p in module.parameters()) - print(f" {name}: {n:,}") + print_memory(device, "初始化后") - map_input = torch.randint(0, 16, (4, 13 * 13)).to(device) # [B=4, 169] - z_input = torch.randn(4, 2, 64).to(device) # [B=4, L=2, d_z=64] - struct_input = torch.tensor([[3, 1, 0, 1], - [0, 2, 1, 0], - [7, 3, 3, 2], - [1, 0, 2, 1]], dtype=torch.long).to(device) # [B=4, 4] - - model.train() - logits = model(map_input, z_input, struct_cond=struct_input) - print(f"\nlogits shape: {logits.shape}") # [4, 169, 16] - - # 无条件模式测试 - logits_uncond = model(map_input, z_input, struct_cond=None) - print(f"logits_uncond shape: {logits_uncond.shape}") # [4, 169, 16] + start = time.perf_counter() + logits = model(map_input, z_input, struct_input) + end = time.perf_counter() print_memory(device, "前向传播后") + + print(f"推理耗时: {end - start:.4f}s") + print(f"输出形状: logits={logits.shape}") + print(f"Tile Embedding parameters: {sum(p.numel() for p in model.tile_embedding.parameters())}") + print(f"Z Projection parameters: {sum(p.numel() for p in model.z_proj.parameters())}") + print(f"Struct Projection parameters: {sum(p.numel() for p in model.struct_proj.parameters())}") + print(f"Transformer parameters: {sum(p.numel() for p in model.transformer.parameters())}") + print(f"Output FC parameters: {sum(p.numel() for p in model.output_fc.parameters())}") + print(f"Total parameters: {sum(p.numel() for p in model.parameters())}") diff --git a/ginka/train_seperated.py b/ginka/train_seperated.py new file mode 100644 index 0000000..b5e7609 --- /dev/null +++ b/ginka/train_seperated.py @@ -0,0 +1,666 @@ +import argparse +import math +import os +import sys +import random +from datetime import datetime + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +import torch.optim as optim +from tqdm import tqdm +from torch.utils.data import DataLoader + +from .vqvae.quantize import VectorQuantizer +from .vqvae.model import GinkaVQVAE +from .maskGIT.model import GinkaMaskGIT +from .dataset import GinkaSeperatedDataset +from shared.image import matrix_to_image_cv + +# 三阶段级联地图生成训练脚本 +# +# 整体架构: +# VQ-VAE(三组独立编码器 vq1/vq2/vq3)将三阶段地图上下文分别编码为离散潜变量, +# 再由共用 VectorQuantizer 统一量化为 z_q; +# 三个独立 MaskGIT(mg1/mg2/mg3)分别以 z_q 和 struct_inject 为条件, +# 逐阶段迭代解码地图图块序列。 +# +# 三阶段生成目标: +# stage1 → floor / wall(地图骨架) +# stage2 → door / monster / entrance(功能性实体) +# stage3 → resource(资源点) + +# 图块 ID 定义: +# 0. 空地 1. 墙壁 2. 门 3. 资源 4. 怪物 5. 入口 6. 掩码(MASK_TOKEN) + +# 共用 VQ-VAE 超参 +# 三组编码器(vq1/vq2/vq3)共享相同超参,分别对三阶段地图上下文独立编码 +VQ_L = 2 # 码字序列长度(每个编码器输出 L 个码字,量化后合并为 L*3) +VQ_K = 8 # codebook 大小(离散码本条目数) +VQ_D_Z = 64 # 码字维度 +VQ_BETA = 0.5 # commit loss 权重(防止编码器输出漂离 codebook) +VQ_GAMMA = 0.0 # entropy loss 权重(当前未启用) +VQ_LAYERS = 3 # VQ-VAE Transformer 层数 +VQ_DIM_FF = 512 # VQ-VAE 前馈网络隐层维度 +VQ_D_MODEL = 64 # VQ-VAE Transformer 模型维度 +VQ_NHEAD = 8 # VQ-VAE 多头注意力头数 + +# 第一阶段 MaskGIT 超参 +STAGE1_MG_DMODEL = 192 +STAGE1_MG_NHEAD = 8 +STAGE1_MG_NUM_LAYERS = 6 +STAGE1_MG_DIM_FF = 1024 + +# 第二阶段 MaskGIT 超参 +STAGE2_MG_DMODEL = 192 +STAGE2_MG_NHEAD = 8 +STAGE2_MG_NUM_LAYERS = 6 +STAGE2_MG_DIM_FF = 1024 + +# 第三阶段 MaskGIT 超参 +STAGE3_MG_DMODEL = 192 +STAGE3_MG_NHEAD = 8 +STAGE3_MG_NUM_LAYERS = 6 +STAGE3_MG_DIM_FF = 1024 + +# 三阶段 Focal Loss 损失权重(可调节各阶段对总损失的贡献比例) +STAGE1_FOCAL_WEIGHT = 1.0 +STAGE2_FOCAL_WEIGHT = 1.0 +STAGE3_FOCAL_WEIGHT = 1.0 + +# 各阶段 VQ commit loss 权重(当前未单独使用,统一由 VQ_BETA 控制) +STAGE1_VQ_WEIGHT = 0.5 +STAGE2_VQ_WEIGHT = 0.5 +STAGE3_VQ_WEIGHT = 0.5 + +# 全局参数 +NUM_CLASSES = 7 # 图块类型数 +MASK_TOKEN = 6 # 掩码图块 +MAP_W = 13 # 地图宽度 +MAP_H = 13 # 地图高度 +MAP_SIZE = MAP_W * MAP_H # 地图大小 +GENERATE_STEP = 18 # MaskGIT 采样步数 +SUBSET2_WALL_PROB = 0.7 # 子集2 进行墙壁掩码的概率 +SUBSET_WEIGHTS = (0.5, 0.3, 0.2) # 每个子集的概率 + +MG_Z_DROPOUT = 0.1 # z 隐变量 Dropout 概率 +MG_STRUCT_DROPOUT = 0.1 # 结构参量 Dropout 概率 + +# 损失参数 +FOCAL_GAMMA = 2.0 # Focal Loss 参数 +VQ_BETA = 0.5 # 承诺损失权重 + +# 训练超参 +BATCH_SIZE = 64 # 每批样本数 +LR = 1e-4 # AdamW 初始学习率 +MIN_LR = 1e-6 # 余弦退火最低学习率 +WEIGHT_DECAY = 1e-4 # L2 正则化系数 +EPOCHS = 400 # 总训练轮数 +CHECKPOINT = 20 # 每隔多少 epoch 保存检查点并执行验证 + +device = torch.device( + "cuda:1" if torch.cuda.is_available() + else "mps" if torch.backends.mps.is_available() + else "cpu" +) + +disable_tqdm = not sys.stdout.isatty() + +def _str2bool(v: str): + if isinstance(v, bool): return v + if v.lower() in ('true', '1', 'yes'): return True + if v.lower() in ('false', '0', 'no'): return False + raise argparse.ArgumentTypeError(f"布尔值应为 True/False,收到: {v!r}") + +def parse_arguments(): + parser = argparse.ArgumentParser(description="三阶段级联训练") + parser.add_argument("--resume", type=_str2bool, default=False) + parser.add_argument("--state", type=str, default="", help="续训时检查点路径") + parser.add_argument("--train", type=str, default="ginka-dataset.json") + parser.add_argument("--validate", type=str, default="ginka-eval.json") + parser.add_argument("--load_optim", type=_str2bool, default=True) + return parser.parse_args() + +def build_model(device: torch.device): + # 三组 VQ-VAE 编码器:各自独立编码一个阶段的地图上下文(encoder_stage1/2/3) + # 输出形状均为 [B, L, d_z],拼接后送入共用 quantizer + vq_kwargs = dict( + num_classes=NUM_CLASSES, L=VQ_L, K=VQ_K, d_model=VQ_D_MODEL, + nhead=VQ_NHEAD, num_layers=VQ_LAYERS, dim_ff=VQ_DIM_FF, map_size=MAP_SIZE + ) + vq1 = GinkaVQVAE(**vq_kwargs).to(device) # 编码 stage1 上下文(floor/wall) + vq2 = GinkaVQVAE(**vq_kwargs).to(device) # 编码 stage2 上下文(door/monster/entrance) + vq3 = GinkaVQVAE(**vq_kwargs).to(device) # 编码 stage3 上下文(resource) + + # 三个独立 MaskGIT 解码器,均接收完整的三阶段 z_q 作为条件 + mg1 = GinkaMaskGIT( + num_classes=NUM_CLASSES, d_model=STAGE1_MG_DMODEL, d_z=VQ_D_Z, dim_ff=STAGE1_MG_DIM_FF, + nhead=STAGE1_MG_NHEAD, num_layers=STAGE1_MG_NUM_LAYERS, map_size=MAP_SIZE + ).to(device) + mg2 = GinkaMaskGIT( + num_classes=NUM_CLASSES, d_model=STAGE2_MG_DMODEL, d_z=VQ_D_Z, dim_ff=STAGE2_MG_DIM_FF, + nhead=STAGE2_MG_NHEAD, num_layers=STAGE2_MG_NUM_LAYERS, map_size=MAP_SIZE + ).to(device) + mg3 = GinkaMaskGIT( + num_classes=NUM_CLASSES, d_model=STAGE3_MG_DMODEL, d_z=VQ_D_Z, dim_ff=STAGE3_MG_DIM_FF, + nhead=STAGE3_MG_NHEAD, num_layers=STAGE3_MG_NUM_LAYERS, map_size=MAP_SIZE + ).to(device) + + # 六个模型参数合并到同一优化器,端到端联合训练 + all_params = ( + list(vq1.parameters()) + list(vq2.parameters()) + list(vq3.parameters()) + + list(mg1.parameters()) + list(mg2.parameters()) + list(mg3.parameters()) + ) + optimizer = optim.AdamW(all_params, lr=LR, weight_decay=1e-4) + # 余弦退火:从 LR 线性衰减至 MIN_LR,周期为全部训练轮数 + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=MIN_LR) + + # 共用 VectorQuantizer:不参与梯度更新,仅在前向时做码本查表 + quantizer = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z) + + return vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler + +def focal_loss(logits, target): + # logits: [B, L, C],需转为 [B, C, L] 以匹配 cross_entropy 期望格式 + ce = F.cross_entropy(logits.permute(0, 2, 1), target, reduction='none') + pt = torch.exp(-ce) # pt = 模型对正确类的预测概率 + # Focal Loss:对高置信度样本降低权重,让模型更专注于难样本 + focal = ((1 - pt) ** FOCAL_GAMMA) * ce + return focal.mean() + +def random_struct(device: torch.device) -> torch.Tensor: + # 随机采样一组结构参量,用于无条件自由生成 + # struct_inject 格式:[cond_sym(0-7), cond_room(0-2), cond_branch(0-2), cond_outer(0-1)] + cond_sym = random.randint(0, 7) # 地图对称类型 + cond_room = random.randint(0, 2) # 房间数量档位 + cond_branch = random.randint(0, 2) # 分支复杂度档位 + cond_outer = random.randint(0, 1) # 是否有外围走廊 + return torch.LongTensor([cond_sym, cond_room, cond_branch, cond_outer]).unsqueeze(0).to(device) + +def maskgit_sample( + model: torch.nn.Module, inp: torch.Tensor, z: torch.Tensor, + struct: torch.Tensor, steps: int +) -> np.ndarray: + current = inp.clone() + + # 迭代去掩码:每步根据置信度分数重新决定掩码位置 + for step in range(steps): + logits = model(current, z, struct) + probs = F.softmax(logits, dim=-1) + + dist = torch.distributions.Categorical(probs) + sampled = dist.sample() + + confidences = torch.gather(probs, -1, sampled.unsqueeze(-1)).squeeze(-1) + + # 余弦退火调度:随步数推进,保留掩码的位置数量递减至 0 + ratio = math.cos(((step + 1) / steps) * math.pi / 2) + num_to_mask = math.floor(ratio * MAP_SIZE) + + # 输入中已有的非掩码位(来自上一阶段)保持不变 + fixed_mask = (current[0] != MASK_TOKEN) + sampled[0, fixed_mask] = current[0, fixed_mask] + confidences[0, fixed_mask] = 1.0 + + if num_to_mask > 0: + # 将置信度最低的位重新掩码,留待下一步重新预测 + _, mask_indices = torch.topk(confidences[0], k=num_to_mask, largest=False) + sampled[0].scatter_(0, mask_indices, MASK_TOKEN) + + current = sampled + + if (current[0] == MASK_TOKEN).sum() == 0: + break + + # 兜底:若仍有残余掩码位(理论上不应发生),用 argmax 确定性填充 + still_masked = (current[0] == MASK_TOKEN) + if still_masked.any(): + logits = model(current, z, struct) + current[0, still_masked] = torch.argmax(logits[0, still_masked], dim=-1) + + return current[0].cpu().numpy().reshape(MAP_H, MAP_W) + +def full_generate_random_z( + input: torch.Tensor, + struct: torch.Tensor, + models: list[torch.nn.Module], + device: torch.device +) -> tuple: + vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler = models + + with torch.no_grad(): + z = quantizer.sample(1, VQ_L, device) + + # stage1:生成 floor/wall 骨架 + pred1_np = maskgit_sample(mg1, input.clone(), z, struct, GENERATE_STEP) + inp2 = torch.tensor(pred1_np.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE) + inp2[inp2 == 0] = MASK_TOKEN # 空地位交由 stage2 填充 + + # stage2:在骨架上生成 door/monster/entrance,非零结果覆盖合并 + pred2_np = maskgit_sample(mg2, inp2, z, struct, GENERATE_STEP) + merged12 = pred1_np.copy() + merged12[pred2_np != 0] = pred2_np[pred2_np != 0] + inp3 = torch.tensor(merged12.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE) + inp3[inp3 == 0] = MASK_TOKEN + + # stage3:填充 resource + pred3_np = maskgit_sample(mg3, inp3, z, struct, GENERATE_STEP) + merged123 = merged12.copy() + merged123[pred3_np != 0] = pred3_np[pred3_np != 0] + + return pred1_np, merged12, merged123 + +def full_generate_specific_z( + input: torch.Tensor, + z: torch.Tensor, + struct: torch.Tensor, + models: list[torch.nn.Module], + device: torch.device +) -> tuple: + vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler = models + + with torch.no_grad(): + # 与 full_generate_random_z 相同的三阶段级联,但使用给定的 z + pred1_np = maskgit_sample(mg1, input.clone(), z, struct, GENERATE_STEP) + inp2 = torch.tensor(pred1_np.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE) + inp2[inp2 == 0] = MASK_TOKEN + + pred2_np = maskgit_sample(mg2, inp2, z, struct, GENERATE_STEP) + merged12 = pred1_np.copy() + merged12[pred2_np != 0] = pred2_np[pred2_np != 0] + inp3 = torch.tensor(merged12.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE) + inp3[inp3 == 0] = MASK_TOKEN + + pred3_np = maskgit_sample(mg3, inp3, z, struct, GENERATE_STEP) + merged123 = merged12.copy() + merged123[pred3_np != 0] = pred3_np[pred3_np != 0] + + return pred1_np, merged12, merged123 + +# 验证可视化 part1:3×3 网格;行1=编码器输入,行2=掩码输入,行3=三阶段预测(合并) +def visualize_part1(batch, logits1, logits2, logits3, tile_dict): + SEP = 3 + TILE_SIZE = 32 + img_h = MAP_H * TILE_SIZE + img_w = MAP_W * TILE_SIZE + + def to_img(mat): + return matrix_to_image_cv(mat, tile_dict, TILE_SIZE) + + pred1 = torch.argmax(logits1[0], dim=-1).cpu().numpy().reshape(MAP_H, MAP_W) + pred2 = torch.argmax(logits2[0], dim=-1).cpu().numpy().reshape(MAP_H, MAP_W) + pred3 = torch.argmax(logits3[0], dim=-1).cpu().numpy().reshape(MAP_H, MAP_W) + + pred3_merged = pred1.copy() + pred3_merged[pred2 != 0] = pred2[pred2 != 0] + pred3_merged[pred3 != 0] = pred3[pred3 != 0] + + enc1_np = batch["encoder_stage1"][0].numpy().reshape(MAP_H, MAP_W) + enc2_np = batch["encoder_stage2"][0].numpy().reshape(MAP_H, MAP_W) + enc3_np = batch["encoder_stage3"][0].numpy().reshape(MAP_H, MAP_W) + inp1_np = batch["input_stage1"][0].numpy().reshape(MAP_H, MAP_W) + inp2_np = batch["input_stage2"][0].numpy().reshape(MAP_H, MAP_W) + inp3_np = batch["input_stage3"][0].numpy().reshape(MAP_H, MAP_W) + + rows = [ + [to_img(enc1_np), to_img(enc2_np), to_img(enc3_np)], + [to_img(inp1_np), to_img(inp2_np), to_img(inp3_np)], + [to_img(pred1), to_img(pred2), to_img(pred3_merged)], + ] + grid = np.ones((3 * img_h + 4 * SEP, 3 * img_w + 4 * SEP, 3), dtype=np.uint8) * 255 + for r, row in enumerate(rows): + for c, img in enumerate(row): + y = SEP + r * (img_h + SEP) + x = SEP + c * (img_w + SEP) + grid[y:y + img_h, x:x + img_w] = img + return grid + +# 验证可视化 part2:行1=真实地图三阶段,行2=stage1 输入与使用真实 z 自回归生成的各阶段结果 +def visualize_part2(batch, z_q, models, device, tile_dict): + SEP = 3 + TILE_SIZE = 32 + img_h = MAP_H * TILE_SIZE + img_w = MAP_W * TILE_SIZE + + def to_img(mat): + return matrix_to_image_cv(mat, tile_dict, TILE_SIZE) + + inp1_t = batch["input_stage1"][0:1].to(device).reshape(1, MAP_SIZE) + struct_t = batch["struct_inject"][0:1].to(device) + auto_pred1_np, auto_merged12, auto_merged123 = full_generate_specific_z( + inp1_t, z_q[0:1], struct_t, models, device + ) + + enc1_np = batch["encoder_stage1"][0].numpy().reshape(MAP_H, MAP_W) + enc2_np = batch["encoder_stage2"][0].numpy().reshape(MAP_H, MAP_W) + enc3_np = batch["encoder_stage3"][0].numpy().reshape(MAP_H, MAP_W) + inp1_np = batch["input_stage1"][0].numpy().reshape(MAP_H, MAP_W) + + rows = [ + [to_img(enc1_np), to_img(enc2_np), to_img(enc3_np)], + [to_img(inp1_np), to_img(auto_pred1_np), to_img(auto_merged12), to_img(auto_merged123)], + ] + grid = np.ones((2 * img_h + 3 * SEP, 4 * img_w + 5 * SEP, 3), dtype=np.uint8) * 255 + for r, row in enumerate(rows): + for c, img in enumerate(row): + y = SEP + r * (img_h + SEP) + x = SEP + c * (img_w + SEP) + grid[y:y + img_h, x:x + img_w] = img + return grid + +# 验证可视化 part3:2×3 网格;行1=参考输入+相同 struct 随机 z 生成,行2=随机 struct 生成 +def visualize_part3(batch, models, device, tile_dict): + SEP = 3 + TILE_SIZE = 32 + img_h = MAP_H * TILE_SIZE + img_w = MAP_W * TILE_SIZE + + def to_img(mat): + return matrix_to_image_cv(mat, tile_dict, TILE_SIZE) + + inp1_t = batch["input_stage1"][0:1].to(device).reshape(1, MAP_SIZE) + struct_ref = batch["struct_inject"][0:1].to(device) + inp1_np = batch["input_stage1"][0].numpy().reshape(MAP_H, MAP_W) + + row1 = [to_img(inp1_np)] + for _ in range(2): + _, _, merged123 = full_generate_random_z(inp1_t, struct_ref, models, device) + row1.append(to_img(merged123)) + + row2 = [] + for _ in range(3): + _, _, merged123 = full_generate_random_z(inp1_t, random_struct(device), models, device) + row2.append(to_img(merged123)) + + rows = [row1, row2] + grid = np.ones((2 * img_h + 3 * SEP, 3 * img_w + 4 * SEP, 3), dtype=np.uint8) * 255 + for r, row in enumerate(rows): + for c, img in enumerate(row): + y = SEP + r * (img_h + SEP) + x = SEP + c * (img_w + SEP) + grid[y:y + img_h, x:x + img_w] = img + return grid + +# 验证可视化 part4:2×3 网格;以少量随机墙壁作为种子,纯随机 struct+z 自由生成 +def visualize_part4(models, device, tile_dict): + SEP = 3 + TILE_SIZE = 32 + img_h = MAP_H * TILE_SIZE + img_w = MAP_W * TILE_SIZE + + def to_img(mat): + return matrix_to_image_cv(mat, tile_dict, TILE_SIZE) + + n_walls = random.randint(math.floor(MAP_SIZE * 0.02), math.floor(MAP_SIZE * 0.06)) + seed = torch.full((1, MAP_SIZE), MASK_TOKEN, dtype=torch.long, device=device) + wall_pos = torch.randperm(MAP_SIZE, device=device)[:n_walls] + seed[0, wall_pos] = 1 + seed_np = seed[0].cpu().numpy().reshape(MAP_H, MAP_W) + + results = [] + for _ in range(5): + _, _, merged123 = full_generate_random_z(seed, random_struct(device), models, device) + results.append(to_img(merged123)) + + row1 = [to_img(seed_np)] + results[:2] + row2 = results[2:] + rows = [row1, row2] + grid = np.ones((2 * img_h + 3 * SEP, 3 * img_w + 4 * SEP, 3), dtype=np.uint8) * 255 + for r, row in enumerate(rows): + for c, img in enumerate(row): + y = SEP + r * (img_h + SEP) + x = SEP + c * (img_w + SEP) + grid[y:y + img_h, x:x + img_w] = img + return grid + +def visualize_validate( + batch, logits1, logits2, logits3, z_q, + models: list[torch.nn.Module], device: torch.device, tile_dict, epoch: int, batch_idx: int +): + save_dir = f"result/seperated/e{epoch}" + os.makedirs(save_dir, exist_ok=True) + cv2.imwrite(f"{save_dir}/val{batch_idx}.png", visualize_part1(batch, logits1, logits2, logits3, tile_dict)) + cv2.imwrite(f"{save_dir}/full{batch_idx}.png", visualize_part2(batch, z_q, models, device, tile_dict)) + cv2.imwrite(f"{save_dir}/rand{batch_idx}.png", visualize_part3(batch, models, device, tile_dict)) + +def validate(dataloader: DataLoader, models: list[torch.nn.Module], device: torch.device, tile_dict, epoch: int): + vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler = models + + # 切换为推理模式(关闭 Dropout / BatchNorm 统计更新) + for m in [vq1, vq2, vq3, mg1, mg2, mg3]: + m.eval() + + # 累计各阶段损失(跨所有 batch 求和,最终除以 batch 数得到均值) + loss1_total = torch.Tensor([0]).to(device) + loss2_total = torch.Tensor([0]).to(device) + loss3_total = torch.Tensor([0]).to(device) + commit_total = torch.Tensor([0]).to(device) + + idx = 0 + + with torch.no_grad(): + for batch in tqdm(dataloader, leave=False, desc="Validate Progress", disable=disable_tqdm): + # 三阶段各自的掩码输入、预测目标和 VQ 编码器输入 + inp1 = batch["input_stage1"].to(device).reshape(-1, MAP_SIZE) + target1 = batch["target_stage1"].to(device).reshape(-1, MAP_SIZE) + enc1 = batch["encoder_stage1"].to(device).reshape(-1, MAP_SIZE) + + inp2 = batch["input_stage2"].to(device).reshape(-1, MAP_SIZE) + target2 = batch["target_stage2"].to(device).reshape(-1, MAP_SIZE) + enc2 = batch["encoder_stage2"].to(device).reshape(-1, MAP_SIZE) + + inp3 = batch["input_stage3"].to(device).reshape(-1, MAP_SIZE) + target3 = batch["target_stage3"].to(device).reshape(-1, MAP_SIZE) + enc3 = batch["encoder_stage3"].to(device).reshape(-1, MAP_SIZE) + + struct = batch["struct_inject"].to(device) + + # VQ 编码:各阶段独立编码后拼接、量化 + z_e1 = vq1(enc1) # [B, L, d_z] + z_e2 = vq2(enc2) + z_e3 = vq3(enc3) + + z_e_all = torch.cat([z_e1, z_e2, z_e3], dim=1) # [B, L*3, d_z] + z_q, _, commit_loss = quantizer(z_e_all) # [B, L*3, d_z] + + # 三阶段 MaskGIT 推理(均以完整 z_q 和 struct 为条件) + logits1 = mg1(inp1, z_q, struct) + logits2 = mg2(inp2, z_q, struct) + logits3 = mg3(inp3, z_q, struct) + + loss1_total += focal_loss(logits1, target1) + loss2_total += focal_loss(logits2, target2) + loss3_total += focal_loss(logits3, target3) + commit_total += commit_loss + + # 每个 batch 生成三种可视化图(val/full/rand) + visualize_validate(batch, logits1, logits2, logits3, z_q, models, device, tile_dict, epoch, idx) + idx += 1 + + # 每个 epoch 额外生成一张无条件自由生成图(不依赖任何 batch 样本) + save_dir = f"result/seperated/e{epoch}" + os.makedirs(save_dir, exist_ok=True) + cv2.imwrite(f"{save_dir}/free.png", visualize_part4(models, device, tile_dict)) + + # 恢复训练模式 + for m in [vq1, vq2, vq3, mg1, mg2, mg3]: + m.train() + + return loss1_total, loss2_total, loss3_total, commit_total + +def train(device: torch.device): + args = parse_arguments() + + models = build_model(device) + vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler = models + + start_epoch = 0 + + if args.resume: + # 从指定检查点恢复:加载所有模型权重及训练状态 + ckpt = torch.load(args.state, map_location=device) + vq1.load_state_dict(ckpt["vq1"]) + vq2.load_state_dict(ckpt["vq2"]) + vq3.load_state_dict(ckpt["vq3"]) + mg1.load_state_dict(ckpt["mg1"]) + mg2.load_state_dict(ckpt["mg2"]) + mg3.load_state_dict(ckpt["mg3"]) + quantizer.load_state_dict(ckpt["quantizer"]) + # load_optim=False 时可跳过优化器/调度器恢复(适合调整学习率后继续训练) + if args.load_optim and "optimizer" in ckpt: + optimizer.load_state_dict(ckpt["optimizer"]) + if args.load_optim and "scheduler" in ckpt: + scheduler.load_state_dict(ckpt["scheduler"]) + start_epoch = ckpt.get("epoch", 0) # 从上次保存的 epoch 继续 + tqdm.write(f"Resumed from epoch {start_epoch}: {args.state}") + + os.makedirs("result/seperated", exist_ok=True) + + dataset = GinkaSeperatedDataset( + args.train, subset_weights=SUBSET_WEIGHTS, subset2_wall_prob=SUBSET2_WALL_PROB + ) + dataloader = DataLoader( + dataset, batch_size=BATCH_SIZE, shuffle=True + ) + + dataset_val = GinkaSeperatedDataset( + args.validate, subset_weights=SUBSET_WEIGHTS, subset2_wall_prob=SUBSET2_WALL_PROB + ) + dataloader_val = DataLoader( + dataset_val, batch_size=min(BATCH_SIZE, len(dataset_val) // 8), shuffle=True + ) + + # 预加载图块图像,键为文件名(不含扩展名),用于可视化时将 ID 映射为像素图 + tile_dict = {} + for f in os.listdir("tiles"): + name = os.path.splitext(f)[0] + img = cv2.imread(f"tiles/{f}", cv2.IMREAD_UNCHANGED) + if img is not None: + tile_dict[name] = img + + for epoch in tqdm(range(start_epoch, EPOCHS), desc="Seperated Training", disable=disable_tqdm): + loss_total = torch.Tensor([0]).to(device) + loss1_total = torch.Tensor([0]).to(device) + loss2_total = torch.Tensor([0]).to(device) + loss3_total = torch.Tensor([0]).to(device) + commit_total = torch.Tensor([0]).to(device) + + for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm): + # 三阶段各自的掩码输入序列、预测目标和编码器上下文 + inp1 = batch["input_stage1"].to(device).reshape(-1, MAP_SIZE) + target1 = batch["target_stage1"].to(device).reshape(-1, MAP_SIZE) + enc1 = batch["encoder_stage1"].to(device).reshape(-1, MAP_SIZE) + + inp2 = batch["input_stage2"].to(device).reshape(-1, MAP_SIZE) + target2 = batch["target_stage2"].to(device).reshape(-1, MAP_SIZE) + enc2 = batch["encoder_stage2"].to(device).reshape(-1, MAP_SIZE) + + inp3 = batch["input_stage3"].to(device).reshape(-1, MAP_SIZE) + target3 = batch["target_stage3"].to(device).reshape(-1, MAP_SIZE) + enc3 = batch["encoder_stage3"].to(device).reshape(-1, MAP_SIZE) + + # 结构条件向量:[cond_sym, cond_room, cond_branch, cond_outer] + struct = batch["struct_inject"].to(device) + + optimizer.zero_grad() + + # VQ 编码:各阶段编码器分别处理各自上下文切片 + z_e1 = vq1(enc1) # [B, L, d_z] + z_e2 = vq2(enc2) + z_e3 = vq3(enc3) + + # 合并三阶段编码后量化 + z_e_all = torch.cat([z_e1, z_e2, z_e3], dim=1) # [B, L*3, d_z] + z_q, _, commit_loss = quantizer(z_e_all) # [B, L*3, d_z] + + # 三阶段 MaskGIT 前向(均接收完整三阶段 z_q) + logits1 = mg1(inp1, z_q, struct) + logits2 = mg2(inp2, z_q, struct) + logits3 = mg3(inp3, z_q, struct) + + # 三阶段 Focal Loss + VQ commit loss 加权求和 + loss1 = focal_loss(logits1, target1) + loss2 = focal_loss(logits2, target2) + loss3 = focal_loss(logits3, target3) + loss1_weighted = STAGE1_FOCAL_WEIGHT * loss1 + loss2_weighted = STAGE2_FOCAL_WEIGHT * loss2 + loss3_weighted = STAGE3_FOCAL_WEIGHT * loss3 + commit_weighted = VQ_BETA * commit_loss + loss = loss1_weighted + loss2_weighted + loss3_weighted + commit_weighted + + loss.backward() + optimizer.step() + + # detach 后累加,避免保留计算图占用显存 + loss_total += loss.detach() + loss1_total += loss1.detach() + loss2_total += loss2.detach() + loss3_total += loss3.detach() + commit_total += commit_loss.detach() + + # 每个 epoch 结束后更新学习率 + scheduler.step() + + data_length = len(dataloader) + tqdm.write( + f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " + f"E: {epoch + 1} | Loss: {loss_total.item() / data_length:.6f} | " + f"L1: {loss1_total.item() / data_length:.6f} | " + f"L2: {loss2_total.item() / data_length:.6f} | " + f"L3: {loss3_total.item() / data_length:.6f} | " + f"VQ: {commit_total.item() / data_length:.6f} | " + f"LR: {scheduler.get_last_lr()[0]:.6f}" + ) + + # 每 CHECKPOINT 个 epoch 执行一次验证、可视化和检查点保存 + if (epoch + 1) % CHECKPOINT == 0: + losses = validate(dataloader_val, models, device, tile_dict, epoch + 1) + loss1_total, loss2_total, loss3_total, commit_total = losses + loss1_weighted = STAGE1_FOCAL_WEIGHT * loss1_total + loss2_weighted = STAGE2_FOCAL_WEIGHT * loss2_total + loss3_weighted = STAGE3_FOCAL_WEIGHT * loss3_total + commit_weighted = VQ_BETA * commit_total + loss_total = loss1_weighted + loss2_weighted + loss3_weighted + commit_weighted + + data_length = len(dataloader_val) + tqdm.write( + f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " + f"E: {epoch + 1} | Loss: {loss_total.item() / data_length:.6f} | " + f"L1: {loss1_total.item() / data_length:.6f} | " + f"L2: {loss2_total.item() / data_length:.6f} | " + f"L3: {loss3_total.item() / data_length:.6f} | " + f"VQ: {commit_total.item() / data_length:.6f} | " + ) + + ckpt_path = f"result/seperated/sep-{epoch + 1}.pth" + torch.save({ + "epoch": epoch + 1, + "vq1": vq1.state_dict(), + "vq2": vq2.state_dict(), + "vq3": vq3.state_dict(), + "mg1": mg1.state_dict(), + "mg2": mg2.state_dict(), + "mg3": mg3.state_dict(), + "quantizer": quantizer.state_dict(), + "optimizer": optimizer.state_dict(), + "scheduler": scheduler.state_dict(), + }, ckpt_path) + tqdm.write(f"Saved checkpoint: {ckpt_path}") + + # 训练结束后保存最终完整权重(含优化器状态,可用于后续续训或推理) + final_path = "result/seperated.pth" + torch.save({ + "epoch": EPOCHS, + "vq1": vq1.state_dict(), + "vq2": vq2.state_dict(), + "vq3": vq3.state_dict(), + "mg1": mg1.state_dict(), + "mg2": mg2.state_dict(), + "mg3": mg3.state_dict(), + "quantizer": quantizer.state_dict(), + "optimizer": optimizer.state_dict(), + "scheduler": scheduler.state_dict(), + }, final_path) + tqdm.write(f"Training complete. Final model saved: {final_path}") diff --git a/ginka/train_stage.py b/ginka/train_stage.py index 6279e66..d1c7e4c 100644 --- a/ginka/train_stage.py +++ b/ginka/train_stage.py @@ -285,26 +285,26 @@ def make_random_struct_cond(): def make_stage_init(stage: int, context_map: torch.Tensor) -> torch.Tensor: """ - 根据阶段构造 MaskGIT 的推理初始地图。 + 根据阶段构造 MaskGIT 的推理初始地图(与训练端掩码策略一致)。 - Stage 1: 全 MASK(或保留稀疏 wall 种子) - Stage 2: 保留 floor/wall 上下文,其余 → MASK - Stage 3: 保留完整上下文(floor/wall/door/monster/entrance),resource → MASK + Stage 1: 全 MASK + Stage 2: 只保留 wall(1),floor + 功能元素 → MASK + Stage 3: 保留 wall(1)/door(2)/monster(4)/entrance(5),floor + resource → MASK """ init = context_map.clone() if stage == 1: - # 全 MASK(不依赖上下文地图) init = torch.full_like(init, MASK_TOKEN) elif stage == 2: - # 保留 floor/wall,其余 → MASK - mask = ~torch.isin(init, torch.tensor([0, 1], device=init.device)) - init[mask] = MASK_TOKEN + # 只保留 wall,其余全部 → MASK + keep = torch.isin(init, torch.tensor([1], device=init.device)) + init[~keep] = MASK_TOKEN else: # stage == 3 - # 保留非 resource,resource → MASK - init[init == 3] = MASK_TOKEN + # 保留 wall + 功能元素,floor + resource → MASK + keep = torch.isin(init, torch.tensor([1, 2, 4, 5], device=init.device)) + init[~keep] = MASK_TOKEN return init diff --git a/ginka/vqvae/model.py b/ginka/vqvae/model.py index ad2d56d..8d4842e 100644 --- a/ginka/vqvae/model.py +++ b/ginka/vqvae/model.py @@ -1,8 +1,9 @@ +import time import torch import torch.nn as nn from .quantize import VectorQuantizer from typing import Tuple - +from ..utils import print_memory class _DecodeLayer(nn.Module): """单个解码层:Pre-LN Cross-Attention + Pre-LN FFN。""" @@ -97,62 +98,13 @@ class VQDecodeHead(nn.Module): class GinkaVQVAE(nn.Module): - """ - VQ-VAE 风格地图编码器。 - - 将一张完整的地图([B, H*W] 整数 tile ID 序列)编码为 L 个离散码字, - 输出 z [B, L, d_z] 作为 MaskGIT 模型的生成条件。 - - 架构: - tile embedding + 位置编码 - → L 个可学习 summary token(拼接到序列头部) - → Transformer Encoder(Pre-LN,自注意力) - → 取前 L 个输出 - → 线性投影到 d_z - → VectorQuantizer(直通估计 + 熵最大化正则) - - 设计约束: - - 参数量目标 < 1M - - 不含解码器,z 的语义由 MaskGIT 端的交叉熵损失间接约束 - - z 定位为风格/多样性控制信号,而非结构重建指导 - """ - def __init__( - self, - num_classes: int = 16, - L: int = 2, - K: int = 16, - d_z: int = 64, - d_model: int = 128, - nhead: int = 4, - num_layers: int = 2, - dim_ff: int = 256, - map_size: int = 13 * 13, - beta: float = 0.25, - gamma: float = 0.1, - vq_temp: float = 1.0, + self, num_classes: int = 16, L: int = 2, K: int = 16, d_z: int = 64, d_model: int = 128, + nhead: int = 4, num_layers: int = 2, dim_ff: int = 256, map_size: int = 13 * 13 ): - """ - Args: - num_classes: tile 类别数(含 MASK token) - L: 码字序列长度,即 z 的序列维度 - K: codebook 大小(码字总数) - d_z: 码字嵌入维度 - d_model: Transformer 内部维度 - nhead: 注意力头数 - num_layers: Transformer 层数 - dim_ff: 前馈网络隐层维度 - map_size: 地图 token 总数(H * W) - beta: 承诺损失权重 - gamma: 熵正则损失权重 - vq_temp: VQ 软分配 softmax 温度 - """ super().__init__() self.L = L self.K = K - self.d_z = d_z - self.beta = beta - self.gamma = gamma # Tile 嵌入 self.tile_embedding = nn.Embedding(num_classes, d_model) @@ -165,13 +117,8 @@ class GinkaVQVAE(nn.Module): # Pre-LN Transformer Encoder(训练更稳定) encoder_layer = nn.TransformerEncoderLayer( - d_model=d_model, - nhead=nhead, - dim_feedforward=dim_ff, - batch_first=True, - activation='gelu', - norm_first=True, # Pre-LN - dropout=0.0, + d_model=d_model, nhead=nhead, dim_feedforward=dim_ff, batch_first=True, + activation='gelu', norm_first=True, dropout=0.1, ) self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) @@ -181,127 +128,43 @@ class GinkaVQVAE(nn.Module): nn.LayerNorm(d_z), ) - # 向量量化层 - self.vq = VectorQuantizer(K=K, d_z=d_z, temp=vq_temp) - - def encode(self, map: torch.Tensor) -> torch.Tensor: - """ - 将地图编码为量化前的连续向量序列。 - - Args: - map: [B, H*W] 整数 tile ID - - Returns: - z_e: [B, L, d_z] 量化前的编码向量 - """ - B = map.shape[0] - - x = self.tile_embedding(map) # [B, H*W, d_model] - x = x + self.pos_embedding # [B, H*W, d_model] - - summary = self.summary_tokens.expand(B, -1, -1) # [B, L, d_model] - x = torch.cat([summary, x], dim=1) # [B, L+H*W, d_model] - - x = self.transformer(x) # [B, L+H*W, d_model] - - z_e = self.proj(x[:, :self.L]) # [B, L, d_z] - return z_e - - def encode_soft(self, soft_emb: torch.Tensor) -> torch.Tensor: - """ - 将软嵌入序列编码为量化前的连续向量序列(用于一致性约束)。 - - 与 encode() 的区别:输入是已经过 softmax 加权求和得到的连续嵌入矩阵 - [B, H*W, d_model],而非整数 tile ID。梯度可完整回传到调用方的 logits。 - - Args: - soft_emb: [B, H*W, d_model] softmax 加权 tile 嵌入(已在 d_model 空间) - - Returns: - z_e: [B, L, d_z] 量化前的编码向量 - """ - B = soft_emb.shape[0] - - x = soft_emb + self.pos_embedding # [B, H*W, d_model] - - summary = self.summary_tokens.expand(B, -1, -1) # [B, L, d_model] - x = torch.cat([summary, x], dim=1) # [B, L+H*W, d_model] - - x = self.transformer(x) # [B, L+H*W, d_model] - - z_e = self.proj(x[:, :self.L]) # [B, L, d_z] - return z_e - def forward(self, map: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """ - 完整前向传播:编码 → 量化 → 计算损失。 - - Args: - map: [B, H*W] 整数 tile ID(训练时传入完整真实地图) - - Returns: - z_q: [B, L, d_z] 量化后的 z(含直通梯度),供 MaskGIT 使用 - z_e: [B, L, d_z] 量化前的连续编码向量,供一致性约束使用 - indices: [B, L] 每个位置对应的码字索引 - vq_loss: scalar VQ 总损失 = beta * commit_loss + gamma * entropy_loss - commit_loss: scalar - entropy_loss: scalar - """ - z_e = self.encode(map) - z_q, indices, commit_loss, entropy_loss = self.vq(z_e) - - vq_loss = self.beta * commit_loss + self.gamma * entropy_loss - return z_q, z_e, indices, vq_loss, commit_loss, entropy_loss - - def sample(self, B: int, device: torch.device) -> torch.Tensor: - """ - 推理阶段:从 codebook 中随机均匀采样 L 个码字。 - - Args: - B: batch size - device: 目标设备 - - Returns: - z: [B, L, d_z] - """ - indices = torch.randint(0, self.K, (B, self.L), device=device) - z = self.vq.codebook(indices) # [B, L, d_z] - return z + # map: [B, H * W] + B, L = map.shape + x = self.tile_embedding(map) # [B, H * W, d_model] + x = x + self.pos_embedding # [B, H * W, d_model] + + summary = self.summary_tokens.expand(B, -1, -1) # [B, L, d_model] + x = torch.cat([summary, x], dim=1) # [B, L+H*W, d_model] + + x = self.transformer(x) + + z_e = self.proj(x[:, :self.L]) + + return z_e if __name__ == "__main__": - device = torch.device("cpu") + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + map_input = torch.randint(0, 7, (4, 13 * 13)).to(device) # [B=4, 169] model = GinkaVQVAE( - num_classes=16, - L=2, - K=16, - d_z=64, - d_model=128, - nhead=4, - num_layers=2, - dim_ff=256, - map_size=13 * 13, + num_classes=7, L=2, K=16, d_z=64, d_model=128, + nhead=4, num_layers=2, dim_ff=256, map_size=13 * 13, ).to(device) - total_params = sum(p.numel() for p in model.parameters()) - print(f"总参数量: {total_params:,} ({total_params / 1e6:.3f}M)") + print_memory(device, "初始化后") - # 分模块参数统计 - for name, module in model.named_children(): - n = sum(p.numel() for p in module.parameters()) - print(f" {name}: {n:,}") + start = time.perf_counter() + z_e = model(map_input) + end = time.perf_counter() - # 前向传播测试 - map_input = torch.randint(0, 15, (4, 13 * 13)).to(device) # [B=4, 169] + print_memory(device, "前向传播后") - z_q, z_e, indices, vq_loss, commit_loss, entropy_loss = model(map_input) - - print(f"\nz_q shape: {z_q.shape}") # [4, 2, 64] - print(f"z_e shape: {z_e.shape}") # [4, 2, 64] - print(f"indices shape:{indices.shape}") # [4, 2] - print(f"vq_loss: {vq_loss.item():.4f}") - - # 推理采样测试 - z_sample = model.sample(B=4, device=device) - print(f"sample shape: {z_sample.shape}") # [4, 2, 64] + print(f"推理耗时: {end - start:.4f}s") + print(f"输出形状: z_e={z_e.shape}") + print(f"Tile Embedding parameters: {sum(p.numel() for p in model.tile_embedding.parameters())}") + print(f"Transformer parameters: {sum(p.numel() for p in model.transformer.parameters())}") + print(f"Projection parameters: {sum(p.numel() for p in model.proj.parameters())}") + print(f"Total parameters: {sum(p.numel() for p in model.parameters())}") diff --git a/ginka/vqvae/quantize.py b/ginka/vqvae/quantize.py index 6f23e11..7a8c0ab 100644 --- a/ginka/vqvae/quantize.py +++ b/ginka/vqvae/quantize.py @@ -3,20 +3,8 @@ import torch.nn as nn import torch.nn.functional as F from typing import Tuple - class VectorQuantizer(nn.Module): - """ - 向量量化层(Vector Quantization)。 - - 将连续的编码向量序列映射到离散的 codebook 码字索引, - 并通过直通估计(Straight-Through Estimator)保持梯度流。 - - 均匀分布正则化采用软分配熵最大化方案: - 通过对距离做 softmax 得到软分配概率,计算平均码字使用率的熵, - 最小化负熵以鼓励所有码字被均等使用。 - """ - - def __init__(self, K: int, d_z: int, temp: float = 1.0): + def __init__(self, K: int, d_z: int): """ Args: K: codebook 大小(码字数量) @@ -26,12 +14,12 @@ class VectorQuantizer(nn.Module): super().__init__() self.K = K self.d_z = d_z - self.temp = temp self.codebook = nn.Embedding(K, d_z) nn.init.uniform_(self.codebook.weight, -1.0 / K, 1.0 / K) def forward(self, z_e: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # z_e: [B, L * 3, d_z] """ Args: z_e: [B, L, d_z] 编码器输出的连续向量序列 @@ -40,28 +28,25 @@ class VectorQuantizer(nn.Module): z_q_st: [B, L, d_z] 量化后向量(直通梯度) indices: [B, L] 每个位置对应的码字索引 commit_loss: scalar 承诺损失 ||z_e - sg(z_q)||^2 - entropy_loss: scalar 负熵损失(最小化 = 最大化码字使用均匀度) """ B, L, d_z = z_e.shape - # 展平到 [B*L, d_z] - z_flat = z_e.reshape(B * L, d_z) + z_flat = z_e.reshape(B * L, d_z) # [B * L * 3, d_z] codebook_w = self.codebook.weight # [K, d_z] # 计算 L2 距离:||z_e - e_k||^2 = ||z_e||^2 + ||e_k||^2 - 2 * z_e · e_k # distances: [B*L, K] - distances = ( - (z_flat ** 2).sum(dim=1, keepdim=True) # [B*L, 1] - + (codebook_w ** 2).sum(dim=1) # [K] - - 2.0 * z_flat @ codebook_w.t() # [B*L, K] - ) + ze_square = torch.sum(z_flat ** 2, dim=1, keepdim=True) + ek_square = torch.sum(codebook_w ** 2, dim=1) + mul = z_flat @ codebook_w.t() + distances = ze_square + ek_square - 2 * mul # Hard assignment:取最近码字索引 - indices = distances.argmin(dim=1) # [B*L] + indices = distances.argmin(dim=1) # [B*L] # 量化向量 - z_q_flat = self.codebook(indices) # [B*L, d_z] + z_q_flat = self.codebook(indices) # [B*L, d_z] z_q = z_q_flat.reshape(B, L, d_z) # 直通估计:前向传 z_q,反向传 z_e 的梯度 @@ -70,12 +55,14 @@ class VectorQuantizer(nn.Module): # 承诺损失:拉近编码向量与其对应的码字(仅更新编码器) commit_loss = F.mse_loss(z_e, z_q.detach()) - # 熵最大化正则:通过软分配计算平均码字使用率,最小化负熵 - # soft_assign: [B*L, K],对距离做 softmax(距离越小,概率越大) - soft_assign = F.softmax(-distances / self.temp, dim=1) - avg_assign = soft_assign.mean(dim=0) # [K],平均码字使用率 - # entropy_loss = -H(p) = sum(p * log(p)),最小化即最大化熵 - entropy_loss = (avg_assign * torch.log(avg_assign + 1e-10)).sum() - indices = indices.reshape(B, L) - return z_q_st, indices, commit_loss, entropy_loss + return z_q_st, indices, commit_loss + + def sample(self, B: int, L: int, device: torch.device) -> torch.Tensor: + indices1 = torch.randint(0, self.K, (B, L), device=device) + indices2 = torch.randint(0, self.K, (B, L), device=device) + indices3 = torch.randint(0, self.K, (B, L), device=device) + z1 = self.codebook(indices1) + z2 = self.codebook(indices2) + z3 = self.codebook(indices3) + return torch.cat([z1, z2, z3], dim=1) diff --git a/prompt.md b/prompt.md new file mode 100644 index 0000000..0897989 --- /dev/null +++ b/prompt.md @@ -0,0 +1,76 @@ +# Ginka 地图生成器 - Copilot 指引 + +## 项目概述 + +本项目是一个基于深度学习的二维网格状地图生成模型,用于生成魔塔(Magic Tower)类网页游戏地图。 + +- **模型结构**:VQ-VAE 风格编码器 + MaskGIT 解码器 + - VQ-VAE 编码器将完整地图压缩为离散隐变量 z(从 codebook 查得) + - MaskGIT 以 z 为条件,通过迭代掩码预测生成地图 + - 推理时直接随机采样 z,无需用户输入 +- **地图规格**:13×13 格子,7 类图块 +- **目录结构** + - `ginka/` — 模型定义与训练脚本(Python) + - `data/` — 数据预处理(TypeScript,因游戏是网页游戏) + - `docs/` — 设计文档 + - `shared/` — 可视化等共享工具 + +## 重要约束 + +### 训练 + +- **不要在当前设备上运行训练**,训练在其他设备上进行 +- 可以运行小规模验证、推理或单步测试,但不要触发完整训练流程 + +### 代码风格 + +#### Python + +- 不使用三引号注释(`"""..."""`),一律改用 `#` 注释 +- 不出现连续空行(即空行仅允许连续出现一行)不出现连续空格,例如下面的例子就不允许出现: + + ```python + a = func1() + abcdef = func2() + ``` + + 应改为: + + ```python + a = func1() + abcdef = func2() + ``` + +- 遵循类似 Prettier 的风格,不出现尾逗号。 +- 不进行无意义的对齐,例如函数参数定义应该遵循这种风格,到达 80 字符左右换行: + + ```python + def func( + param1: type, param2: type, param3: type, + param4: type, param5: type + ) + ``` + + 而不是: + + ```python + def func(param1: type, param2: type, param3: type, + param4: type, param5: type) + ``` + +- 不使用下划线开头命名任何内容,包括私有方法。 +- 不写静态方法。 +- 仅允许在文件开头引入内容,不允许其他地方出现任何 `import`。 +- 文件尾添加空行。 +- 不允许出现连等。 +- 不允许使用元组语法同时给多个量分别赋值,比如 `a, b, c = d, e, f` 不允许出现,仅允许 `a, b, c = func()` 这种一赋多的场景。 +- 不要在文件开头添加注释,开头第一句应该是 `import`。文件注释应该在 `import` 之后写。 + +#### TypeScript + +遵循 Prettier 风格。 + +### 验证与可视化 + +- 编写验证代码时,优先输出可视化结果(图片文件),使用 `shared/image.py` 中的工具 +- 验证阶段应对不同条件(不同 z 采样)分别生成图片,便于直观对比模型效果