From 3676958781e439ffed8c8e9b795f5dfc5bac6b3a Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Thu, 7 May 2026 20:59:22 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=88=86=E4=B8=89=E9=98=B6=E6=AE=B5?= =?UTF-8?q?=E8=AE=AD=E7=BB=83?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Copilot --- docs/three-stage-generation-design.md | 499 ++++++++++++++++++++ ginka/dataset.py | 375 ++++++++++++++- ginka/maskGIT/model.py | 16 +- ginka/train_stage.py | 654 ++++++++++++++++++++++++++ 4 files changed, 1538 insertions(+), 6 deletions(-) create mode 100644 docs/three-stage-generation-design.md create mode 100644 ginka/train_stage.py diff --git a/docs/three-stage-generation-design.md b/docs/three-stage-generation-design.md new file mode 100644 index 0000000..032111a --- /dev/null +++ b/docs/three-stage-generation-design.md @@ -0,0 +1,499 @@ +# 三阶段级联地图生成设计文档 + +## 背景与问题诊断 + +### 当前问题:墙壁过度生成 + +现有单模型方案(VQ-VAE + MaskGIT 联合训练)在实践中呈现出**模型严重偏向生成墙壁**的现象,具体表现为: + +- 生成地图中墙壁密度显著偏高,可通行空间不足; +- 门、怪物、入口等功能性元素稀少甚至缺失; +- 资源类 tile 几乎不出现。 + +### 根本原因分析 + +从 token 分布角度来看,13×13 = 169 个格子中,各类 tile 的分布严重不均: + +| tile 类型 | 典型比例 | 训练信号 | +| ------------- | -------- | -------- | +| 墙壁(wall) | 40%~60% | 强、密集 | +| 空地(floor) | 20%~40% | 强、密集 | +| 门(door) | 1%~5% | 极弱 | +| 怪物 | 5%~15% | 弱 | +| 入口 | 1%~3% | 极弱 | +| 资源 | 5%~15% | 弱 | + +单模型将以上所有类别的 **交叉熵损失**混合在一起优化。由于墙壁和空地占据绝大多数 token,模型可以通过反复预测墙壁/空地获得低训练损失,而无需真正学习如何放置稀有类别。 + +这是**类别不均衡问题的结构性体现**:即使引入 Focal Loss,调整 loss 权重,也难以从根本上解决三类任务学习难度不匹配的问题——因为结构骨架(floor/wall)是其他所有元素放置的前提,混合训练会导致模型困于局部最优解("把所有位置都预测为墙壁")。 + +--- + +## 核心思路:三阶段级联生成 + +将单次全类别生成拆分为**三个独立的 MaskGIT 阶段**,每阶段只负责一组语义相近、结构约束相似的 tile 类别。后续阶段以前序阶段的输出作为已知上下文。 + +``` +┌──────────────────────────────────────────────────────────────────────┐ +│ 阶段一:结构骨架生成 │ +│ │ +│ 全 MASK ──► [Stage1-MaskGIT + z₁] ──► floor/wall 地图 │ +└──────────────────────────────────────────────────────────────────────┘ + │ + ▼ 已知 floor/wall 上下文 +┌──────────────────────────────────────────────────────────────────────┐ +│ 阶段二:功能元素放置 │ +│ │ +│ floor/wall + MASK ──► [Stage2-MaskGIT + z₂] ──► door/monster/入口│ +└──────────────────────────────────────────────────────────────────────┘ + │ + ▼ 已知 floor/wall/door/monster/入口 上下文 +┌──────────────────────────────────────────────────────────────────────┐ +│ 阶段三:资源放置 │ +│ │ +│ 完整上下文 + MASK ──► [Stage3-MaskGIT + z₃] ──► resource │ +└──────────────────────────────────────────────────────────────────────┘ +``` + +### 各阶段职责 + +| 阶段 | 负责类别 | tile 数 | 类别数(含 MASK) | 结构约束强度 | +| ------ | -------------------------------- | ------- | ----------------- | ------------ | +| 阶段一 | floor(0)、wall(1) | 2 | 3 | 极强 | +| 阶段二 | door(2)、monster(4)、entrance(5) | 3 | 6 | 强 | +| 阶段三 | resource(3) | 1 | 3 | 弱 | + +注:各阶段模型的词表不需要只含本阶段 tile,完整词表(7 类)可以保留不变,只是**被 MASK 的位置**和**计算 Loss 的位置**会有所不同(见训练策略)。 + +--- + +## 每阶段 tile 映射与 MASK 策略 + +### 阶段一 + +**输入**:所有位置均填充 MASK token(`tile=6`)。 + +**目标**:预测 floor(0) 和 wall(1);原始地图中所有非 floor/wall 的 tile 在训练目标中被**重映射为 floor(0)**,因为阶段一模型不关心功能性元素的具体种类。 + +```python +STAGE1_REMAP = { + 0: 0, # floor → floor + 1: 1, # wall → wall + 2: 0, # door → floor(视作空地,让阶段二填充) + 3: 0, # resource → floor + 4: 0, # monster → floor + 5: 0, # entrance → floor + 6: 6, # MASK → MASK(不参与损失计算) +} +``` + +**Loss 计算范围**:所有非 MASK 位置(即全局所有位置,因为输入均为 MASK)。 + +### 阶段二 + +**输入**:将阶段一输出的 floor/wall 地图作为固定上下文,原始地图中属于 door/monster/entrance 的位置替换为 MASK token,其余(floor/wall)位置保持不变。 + +```python +STAGE2_MASK_IDS = {2, 4, 5} # 需要在阶段二中被预测的 tile ID +``` + +**训练时输入构造**(使用 GT 地图中的 floor/wall 作为上下文,不使用阶段一的实际输出): + +```python +def make_stage2_input(gt_map: np.ndarray) -> np.ndarray: + """ + gt_map: [H*W] 整数数组,包含完整原始地图 + 返回: stage2 输入地图,door/monster/entrance 位置替换为 MASK + """ + inp = gt_map.copy() + # 先将所有资源归一化为 floor(阶段二不负责资源) + inp[np.isin(inp, [3])] = 0 + # 将 door/monster/entrance 位置 MASK 掉 + inp[np.isin(inp, [2, 4, 5])] = 6 # MASK + return inp +``` + +**目标**:预测 door(2)、monster(4)、entrance(5) 位置的类别(以及它们所在位置原本是否是 floor)。 + +**Loss 计算范围**:仅对输入为 MASK(即原本是 door/monster/entrance 的位置)计算损失,floor/wall 位置不参与损失(它们已经确定)。 + +### 阶段三 + +**输入**:将阶段二输出的完整功能性地图(floor/wall/door/monster/entrance)作为固定上下文,原始地图中资源(tile=3)位置替换为 MASK token。 + +```python +def make_stage3_input(gt_map: np.ndarray) -> np.ndarray: + """ + gt_map: [H*W] 整数数组,包含完整原始地图 + 返回: stage3 输入地图,resource 位置替换为 MASK + """ + inp = gt_map.copy() + inp[inp == 3] = 6 # resource → MASK + return inp +``` + +**Loss 计算范围**:仅对输入为 MASK(即原本是 resource 的位置)计算损失。 + +--- + +## 模型架构 + +### 共享基础架构 + +三个阶段均采用与现有方案**相同的 MaskGIT + VQ-VAE 架构**,不需要设计新的模型结构。核心差异在于: + +1. **输入词表**:与现有方案一致,均使用 7 类词表(`NUM_CLASSES=7`,`MASK_TOKEN=6`); +2. **z 来源**:每阶段使用来自对应 VQ 通道的隐变量; +3. **Loss 掩码**:只对该阶段负责的位置计算 CE Loss。 + +### VQ-VAE z 的阶段分配 + +现有 VQ-VAE 已采用三通道设计(CH1/CH2/CH3),与三个生成阶段自然对应: + +| 通道 | 编码目标 | 供应阶段 | 描述 | +| ---- | ----------------- | -------- | ---------------------------- | +| CH1 | floor/wall 骨架 | 阶段一 | z₁,控制墙壁结构多样性 | +| CH2 | door/monster/入口 | 阶段二 | z₂,控制功能元素布局多样性 | +| CH3 | resource 分布 | 阶段三 | z₃,控制资源密度与位置多样性 | + +``` +真实地图 ──► CH1 encoder ──► z₁ ──► Stage1 MaskGIT + ──► CH2 encoder ──► z₂ ──► Stage2 MaskGIT + ──► CH3 encoder ──► z₃ ──► Stage3 MaskGIT +``` + +**通道专属编码**:各通道编码器在编码前只"看"与自身相关的 tile,其余 tile 视作 floor(0): + +```python +def ch1_mask(gt_map): + """只保留 floor/wall,其余置 0""" + m = gt_map.copy() + m[~np.isin(m, [0, 1])] = 0 + return m + +def ch2_mask(gt_map): + """只保留 door/monster/entrance,其余置 0""" + m = gt_map.copy() + m[~np.isin(m, [2, 4, 5])] = 0 + return m + +def ch3_mask(gt_map): + """只保留 resource,其余置 0""" + m = gt_map.copy() + m[m != 3] = 0 + return m +``` + +### Loss 计算掩码实现 + +训练时,每阶段额外接收一个 `loss_mask: torch.BoolTensor [B, H*W]`,指示哪些位置需要计算损失: + +```python +# 阶段一:所有位置(因为输入全为 MASK) +loss_mask_s1 = torch.ones(B, MAP_SIZE, dtype=torch.bool) + +# 阶段二:只有原本是 door/monster/entrance 的位置 +loss_mask_s2 = torch.isin(raw_map, torch.tensor([2, 4, 5])) + +# 阶段三:只有原本是 resource 的位置 +loss_mask_s3 = (raw_map == 3) +``` + +Focal Loss 修改为只对 `loss_mask` 为 True 的位置求和后做归一化: + +```python +def stage_focal_loss(logits, targets, loss_mask, gamma=2.0): + # logits: [B, C, H*W], targets: [B, H*W], loss_mask: [B, H*W] + per_token_loss = focal_loss(logits, targets, gamma, reduction='none') # [B, H*W] + masked_loss = per_token_loss[loss_mask] + return masked_loss.mean() if masked_loss.numel() > 0 else per_token_loss.mean() +``` + +--- + +## 训练策略 + +### 训练方式:顺序训练(推荐) + +三个阶段**依次训练**,后续阶段训练时使用 GT 地图作为前序阶段的"完美输出"(teacher forcing),而非使用前序阶段模型的实际推理结果。 + +``` +训练阶段一: + data: (全MASK输入, stage1目标) + loss: focal loss on all positions + +训练阶段二: + data: (floor/wall上下文 + MASK输入, stage2目标) + loss: focal loss only on door/monster/entrance positions + +训练阶段三: + data: (floor/wall/door/monster/入口上下文 + MASK输入, stage3目标) + loss: focal loss only on resource positions +``` + +**使用 GT 而非前序模型输出的理由**: + +- 避免误差级联(前序模型若生成错误的骨架,后续模型的训练将在错误分布上进行); +- 各阶段训练更稳定,收敛更快; +- 阶段之间解耦,方便单独迭代和调试。 + +### 各阶段训练子集划分 + +每个阶段均沿用现有 A/B/C/D 子集划分逻辑,但 MASK 策略应用在对应阶段的目标 tile 上: + +| 子集 | 阶段一 | 阶段二 | 阶段三 | +| ---- | ----------------------------- | ------------------------------ | --------------------------- | +| A | 随机遮盖部分 floor/wall | 随机遮盖部分 door/monster/入口 | 随机遮盖部分 resource | +| B | 保留全部 wall,MASK floor | 给定全部 wall,MASK 功能元素 | 给定全部骨架,MASK 部分资源 | +| C | 随机保留部分 wall,MASK 其余 | 同 B | 同 B | +| D | 保留 wall+entrance,MASK 其余 | 给定 wall+entrance,MASK 门/怪 | 同 B | + +### 各阶段专属损失 + +每阶段的总损失: + +$$\mathcal{L}^{(s)} = \mathcal{L}_{CE}^{(s)} + \beta \cdot \mathcal{L}_{commit}^{(s)} + \gamma \cdot \mathcal{L}_{uniform}^{(s)}$$ + +其中 $s \in \{1, 2, 3\}$ 表示阶段编号,$\mathcal{L}_{CE}^{(s)}$ 只在该阶段负责的 tile 位置上计算。 + +--- + +## 推理流程 + +### 完整推理管线 + +``` +1. 随机采样 z₁, z₂, z₃(各自从对应 codebook 均匀采样 L 个 index) + +2. 阶段一推理(结构骨架) + 初始状态: 全部 169 个位置 = MASK token + 迭代 MaskGIT 解码(cosine schedule,约 18 步): + 输入: MASK地图 + z₁ + 输出: 逐步填充 floor/wall,直到无 MASK 位置 + 结果: floor/wall 骨架地图 M₁ + +3. 阶段二推理(功能元素放置) + 初始状态: 继承 M₁,在 floor 位置随机选取候选位置置为 MASK + ─── 或 ─── + 初始状态: 继承 M₁,所有 floor 位置均置为 MASK(让模型决定放置密度) + 迭代 MaskGIT 解码: + 输入: 含已知 wall 的掩码地图 + z₂ + 约束: wall 位置不参与 unmask,保持不变 + 输出: 逐步填充 door/monster/entrance + 结果: 含功能元素的地图 M₂ + +4. 阶段三推理(资源放置) + 初始状态: 继承 M₂,所有 floor 位置(未被阶段二填充的)置为 MASK + 迭代 MaskGIT 解码: + 输入: 含已知 wall/door/monster/入口的掩码地图 + z₃ + 约束: 非 floor 位置保持不变 + 输出: 逐步填充 resource(或保持 floor) + 结果: 完整地图 M₃ +``` + +### 阶段二初始 MASK 策略 + +阶段二的初始状态有两种选择: + +| 策略 | 描述 | 适用场景 | +| ------------- | -------------------------------------------------- | ---------------- | +| 全 floor MASK | 所有 floor 位置均置为 MASK,让模型自主决定放置密度 | 完全随机生成 | +| 候选位置 MASK | 只 MASK 用户指定或随机抽取的少数位置 | 用户指定部分位置 | + +对于完全随机生成场景,**推荐使用全 floor MASK**——此时 z₂ 决定功能元素的总体风格(密集/稀疏/集中/分散),模型负责在此约束下寻找合理位置。 + +### 用户交互场景 + +| 场景 | 阶段一输入 | 阶段二输入 | 阶段三输入 | +| ------------------- | -------------------- | ------------------------ | ---------- | +| 完全随机 | 全 MASK | 继承 M₁ | 继承 M₂ | +| 用户手绘墙壁 | 已知 wall + MASK | 继承 M₁(固定 wall) | 继承 M₂ | +| 用户指定入口 | 已知 entrance + MASK | 继承 M₁(固定 entrance) | 继承 M₂ | +| 用户手绘墙+指定入口 | 已知 wall/entrance | 继承 M₁ | 继承 M₂ | + +--- + +## 实现方案 + +### 方案一:三个独立 GinkaMaskGIT 实例(推荐) + +每个阶段分别实例化一个 `GinkaMaskGIT`,共用同一个 `GinkaVQVAE`(其中已含三通道编码器)。各阶段模型独立训练、独立存储。 + +```python +# 模型结构 +vqvae = GinkaVQVAE(...) # 三通道共享编码器 +stage1 = GinkaMaskGIT(...) # 结构骨架 +stage2 = GinkaMaskGIT(...) # 功能元素 +stage3 = GinkaMaskGIT(...) # 资源 +``` + +**优点**: + +- 各阶段模型完全解耦,可独立调整超参、单独重训; +- 模型大小可以针对任务难度调整(阶段一可以更大,阶段三可以更小); +- 便于调试和增量式开发。 + +**缺点**: + +- 三个模型分别保存,推理时需依次加载; +- 总参数量约为原方案的 3 倍(但可以通过缩小各阶段模型来对冲)。 + +### 方案二:单模型 + 阶段 Embedding(备选) + +复用同一个 `GinkaMaskGIT`,添加阶段 embedding(类似 BERT 的 segment embedding): + +```python +self.stage_embedding = nn.Embedding(3, d_model) # 三个阶段 + +def forward(self, map, z, stage: int): + x = self.tile_embedding(map) + self.pos_embedding + x = x + self.stage_embedding(torch.tensor(stage)) # 阶段条件 + ... +``` + +**优点**:参数量与原方案相同,推理时只需加载一个模型。 +**缺点**:三阶段共享所有权重,可能导致阶段间干扰;阶段一的结构任务与阶段三的稀疏资源任务表示空间差异大,单一模型难以同时擅长。 + +**结论**:**推荐方案一**,尤其是在当前阶段(验证多阶段框架可行性),可先分别训练三个轻量模型进行快速验证。 + +### 各阶段模型规模建议 + +| 阶段 | 任务难度 | 建议 d_model | 建议 num_layers | 参数量估算 | +| ------ | ---------- | ------------ | --------------- | ---------- | +| 阶段一 | 高(结构) | 256 | 6 | ~4M | +| 阶段二 | 中(功能) | 192 | 4 | ~2M | +| 阶段三 | 低(稀疏) | 128 | 3 | ~0.8M | + +--- + +## Dataset 修改方案 + +### 新增 `GinkaStageDataset` + +需要扩展 `dataset.py`,增加针对三阶段的 Dataset 类,或在 `GinkaVQDataset` 中添加 `stage` 参数: + +```python +class GinkaStageDataset(Dataset): + """ + 三阶段级联训练专用 Dataset。 + + 返回 dict: + raw_map: LongTensor [H*W] 完整原始地图(供 VQ-VAE 编码) + stage_input: LongTensor [H*W] 当前阶段 MaskGIT 输入(含上下文 + MASK) + target_map: LongTensor [H*W] CE loss ground truth(等同 raw_map) + loss_mask: BoolTensor [H*W] 只对 True 位置计算损失 + subset: str 子集标识 + """ + + STAGE1_TARGETS = {0, 1} # floor, wall + STAGE2_TARGETS = {2, 4, 5} # door, monster, entrance + STAGE3_TARGETS = {3} # resource +``` + +### 数据构造函数 + +```python +def make_stage1_sample(gt_map: np.ndarray, mask_id: int = 6): + """阶段一:全 MASK 输入,目标是 floor/wall(其余归一为 floor)""" + stage_input = np.full_like(gt_map, mask_id) + target = gt_map.copy() + target[~np.isin(target, [0, 1])] = 0 # 非结构 tile → floor + loss_mask = np.ones_like(gt_map, dtype=bool) + return stage_input, target, loss_mask + +def make_stage2_sample(gt_map: np.ndarray, mask_id: int = 6): + """阶段二:floor/wall 为上下文,door/monster/entrance 位置 MASK""" + stage_input = gt_map.copy() + stage_input[stage_input == 3] = 0 # 资源 → floor(阶段二不负责) + target_ids = np.isin(stage_input, [2, 4, 5]) + stage_input[target_ids] = mask_id # 功能元素 → MASK + target = gt_map.copy() + target[target == 3] = 0 # target 中资源也视为 floor + loss_mask = (gt_map != 0) & (gt_map != 1) & (gt_map != 3) # 只计算功能元素位置 + return stage_input, target, loss_mask + +def make_stage3_sample(gt_map: np.ndarray, mask_id: int = 6): + """阶段三:全上下文保留,只 MASK 资源位置""" + stage_input = gt_map.copy() + stage_input[stage_input == 3] = mask_id # 资源 → MASK + target = gt_map.copy() + loss_mask = (gt_map == 3) # 只计算资源位置 + return stage_input, target, loss_mask +``` + +--- + +## 训练脚本设计 + +### 新增 `ginka/train_stage.py` + +``` +用法示例: + python -m ginka.train_stage --stage 1 + python -m ginka.train_stage --stage 2 + python -m ginka.train_stage --stage 3 --resume True --state result/stage3/stage3-10.pth +``` + +各阶段检查点分别存储到: + +- `result/stage1/stage1-{epoch}.pth` +- `result/stage2/stage2-{epoch}.pth` +- `result/stage3/stage3-{epoch}.pth` + +### 阶段二/三的 VQ 编码器冻结策略 + +阶段二训练时,CH1 编码器(已在阶段一训练中充分收敛)可选择冻结,只更新 CH2 编码器和 Stage2 MaskGIT: + +```python +if args.stage == 2: + for p in vqvae.encoder_ch1.parameters(): + p.requires_grad_(False) # 冻结 CH1 +``` + +类似地,阶段三训练时可冻结 CH1 和 CH2 编码器。 + +--- + +## 与现有方案的对比 + +| 维度 | 现有单模型方案 | 三阶段级联方案 | +| -------------- | ------------------------ | ---------------------------------------- | +| 墙壁过度生成 | 存在,难以从根本解决 | 阶段一单独训练骨架,Loss 聚焦 floor/wall | +| 训练信号均衡性 | 墙壁主导,稀有类欠拟合 | 各阶段 Loss 只计算本阶段 tile,信号均衡 | +| 模型可调试性 | 单一模型,各类别相互干扰 | 各阶段独立,可单独分析每阶段表现 | +| 推理速度 | 1 次完整 MaskGIT 解码 | 3 次级联解码(总步数约为原来 3 倍) | +| 误差累积 | 无 | 存在,前序阶段错误会传播到后续阶段 | +| 用户可控性 | 较难(条件混合) | 好(可在任意阶段注入用户约束) | +| 参数量 | ~4M | 约 ~7M(可通过缩减各阶段规模控制) | + +--- + +## 预期收益 + +1. **解决墙壁过度生成**:阶段一专门针对 floor/wall 训练,类别分布从 7 类压缩到 3 类,Loss 完全聚焦,模型不再有逃避路径; +2. **功能元素召回率提升**:阶段二以已知骨架为前提生成 door/monster/entrance,训练信号不再被墙壁噪声稀释; +3. **资源分布更合理**:阶段三在完整上下文下放置资源,能感知到门/怪物位置,避免资源与关键功能元素重叠; +4. **可交互性增强**:用户可在任意阶段注入约束(固定某些 tile),天然支持层次化编辑。 + +--- + +## 风险与应对 + +| 风险 | 描述 | 应对策略 | +| ---------------- | ----------------------------------------- | -------------------------------------------------------- | +| 误差累积 | 阶段一骨架不准确会导致阶段二/三布局失真 | 阶段一优先保证质量,推理时对骨架做后处理校验 | +| 推理耗时增加 | 三次 MaskGIT 解码约 3 倍耗时 | 减少 MaskGIT 迭代步数(阶段二/三任务更简单,步数可减半) | +| 阶段二稀疏性问题 | 169 格子中功能元素极少,Loss 计算覆盖率低 | 适当提高 stage 2 的 loss_mask 覆盖(周边 floor 也计入) | +| Codebook 对齐 | 三通道 VQ-VAE 的 z 在分阶段训练时各自优化 | 联合训练阶段一+VQ-CH1,联合训练阶段二+VQ-CH2,以此类推 | + +--- + +## 实施顺序 + +- [ ] 新增 `GinkaStageDataset`,实现三阶段数据构造函数 +- [ ] 新增 `ginka/train_stage.py`,支持 `--stage 1/2/3` 参数 +- [ ] 阶段一训练:仅 floor/wall,验证骨架生成质量(关键里程碑) +- [ ] 阶段二训练:以 GT floor/wall 为上下文,验证功能元素召回率 +- [ ] 阶段三训练:以 GT 完整地图为上下文,验证资源放置合理性 +- [ ] 实现三阶段级联推理脚本,接入现有可视化工具 +- [ ] 对比实验:三阶段方案 vs 现有单模型方案在墙壁密度、功能元素召回率、资源分布等指标上的差异 diff --git a/ginka/dataset.py b/ginka/dataset.py index e2d9c12..6d3b834 100644 --- a/ginka/dataset.py +++ b/ginka/dataset.py @@ -4,14 +4,69 @@ import torch import numpy as np from torch.utils.data import Dataset +def _compute_map_labels(map_2d) -> dict: + """ + 从 2D 地图列表(或 numpy 数组)推算结构标签。 + 当 JSON 数据缺少 roomCount / highDegBranchCount / outerWall 字段时调用。 + """ + arr = np.array(map_2d, dtype=np.int64) # [H, W] + H, W = arr.shape + WALL, ENTRY = 1, 5 + + # outerWall:最外圈中 wall+entry 占比 > 90% + border = np.concatenate([arr[0, :], arr[-1, :], arr[1:-1, 0], arr[1:-1, -1]]) + total_b = border.size + outer_wall = int(total_b > 0 and np.sum((border == WALL) | (border == ENTRY)) / total_b > 0.9) + + # roomCount:BFS 统计 floor(0)+resource(3) 连通区域, + # 需满足:总格子 >= 4,外接矩形宽 >= 2 且高 >= 2 + FLOOR_SET = (0, 3) + visited = np.zeros((H, W), dtype=bool) + room_count = 0 + for sy in range(H): + for sx in range(W): + if arr[sy, sx] not in FLOOR_SET or visited[sy, sx]: + continue + queue = [(sy, sx)] + visited[sy, sx] = True + tiles_y, tiles_x = [sy], [sx] + head = 0 + while head < len(queue): + y, x = queue[head]; head += 1 + for dy, dx in ((-1, 0), (1, 0), (0, -1), (0, 1)): + ny, nx = y + dy, x + dx + if 0 <= ny < H and 0 <= nx < W and not visited[ny, nx] and arr[ny, nx] in FLOOR_SET: + visited[ny, nx] = True + queue.append((ny, nx)) + tiles_y.append(ny); tiles_x.append(nx) + if (len(tiles_y) >= 4 + and max(tiles_y) - min(tiles_y) >= 1 + and max(tiles_x) - min(tiles_x) >= 1): + room_count += 1 + + # highDegBranchCount:非 wall 格子中,4 邻域非 wall 邻居 >= 3 的数量 + non_wall = (arr != WALL).astype(np.int32) + padded = np.pad(non_wall, 1, mode='constant', constant_values=0) + nbr_sum = (padded[:-2, 1:-1] + padded[2:, 1:-1] + + padded[1:-1, :-2] + padded[1:-1, 2:]) + high_deg = int(np.sum((non_wall == 1) & (nbr_sum >= 3))) + + return {'outerWall': outer_wall, 'roomCount': room_count, 'highDegBranchCount': high_deg} + + def load_data(path: str): with open(path, 'r', encoding="utf-8") as f: data = json.load(f) - + data_list = [] for value in data["data"].values(): + # 兼容旧版数据集(缺少结构标签字段) + if 'roomCount' not in value: + labels = _compute_map_labels(value['map']) + value.update(labels) + # symmetry 字段由 __getitem__ 在增强后重新计算,此处不需要从 JSON 读取 data_list.append(value) - + return data_list def _compute_symmetry(target_np: np.ndarray) -> tuple: @@ -326,6 +381,322 @@ class GinkaSplitDataset(Dataset): } +# --------------------------------------------------------------------------- +# GinkaStageDataset:三阶段级联训练专用数据集 +# --------------------------------------------------------------------------- + +class GinkaStageDataset(Dataset): + """ + 三阶段级联生成训练专用 Dataset。 + + 每个阶段只预测特定类别的 tile,后续阶段以前序阶段输出作为上下文。 + 训练时统一使用 GT 作为前序上下文(teacher forcing),避免误差级联。 + + 阶段划分: + stage=1 结构骨架:预测 floor(0) + wall(1) + stage=2 功能元素:预测 door(2) + monster(4) + entrance(5),以 floor/wall 为上下文 + stage=3 资源放置:预测 resource(3),以完整骨架为上下文 + + 返回 dict: + raw_map: LongTensor [H*W] 完整原始地图(供 VQ-VAE 编码) + vq_slice: LongTensor [H*W] 当前阶段 VQ 编码器的输入切片 + stage_input: LongTensor [H*W] MaskGIT 输入(含上下文 + MASK 位置) + target_map: LongTensor [H*W] CE loss ground truth + loss_mask: BoolTensor [H*W] 只对 True 位置计算损失 + subset: str 子集标识 A/B/C/D + struct_cond: LongTensor [4] [sym, room, branch, outer] + """ + + FLOOR = 0 + WALL = 1 + DOOR = 2 + RESOURCE = 3 + MONSTER = 4 + ENTRANCE = 5 + MASK_ID = 6 + + STAGE1_TARGETS = frozenset({0, 1}) + STAGE2_TARGETS = frozenset({2, 4, 5}) + STAGE3_TARGETS = frozenset({3}) + + # VQ 切片集合:各阶段编码器只"看"与自身相关的 tile + _VQ_KEEP = { + 1: frozenset({0, 1}), + 2: frozenset({0, 1, 2, 4, 5}), + 3: None, # 完整地图 + } + + def __init__( + self, + data_path: str, + stage: int, + subset_weights: tuple = (0.5, 0.2, 0.2, 0.1), + wall_mask_ratio: float = 0.3, + room_thresholds: tuple = None, + branch_thresholds: tuple = None, + ): + """ + Args: + data_path: JSON 数据文件路径 + stage: 生成阶段 1/2/3 + subset_weights: 子集 (A, B, C, D) 的采样权重,自动归一化 + wall_mask_ratio: Subset C 中额外随机 mask 的 wall 比例上限 + room_thresholds: 等频分箱阈值(None 时自动计算) + branch_thresholds: 等频分箱阈值(None 时自动计算) + """ + assert stage in (1, 2, 3), f"stage 必须是 1/2/3,收到 {stage}" + self.stage = stage + self.data = load_data(data_path) + self.wall_mask_ratio = wall_mask_ratio + + total_w = sum(subset_weights) + normalized = [x / total_w for x in subset_weights] + self.subset_cumw = [sum(normalized[:i + 1]) for i in range(len(normalized))] + + room_counts = [item['roomCount'] for item in self.data] + branch_counts = [item['highDegBranchCount'] for item in self.data] + + if room_thresholds is None: + n = len(room_counts) + rs = sorted(room_counts) + bs = sorted(branch_counts) + th1_r, th2_r = rs[n // 3], rs[2 * n // 3] + th1_b, th2_b = bs[n // 3], bs[2 * n // 3] + if th1_r == th2_r: th2_r = th1_r + 1 + if th1_b == th2_b: th2_b = th1_b + 1 + self.room_th = (th1_r, th2_r) + self.branch_th = (th1_b, th2_b) + else: + self.room_th = room_thresholds + self.branch_th = branch_thresholds + + def to_level(v, th): + return 0 if v < th[0] else (1 if v < th[1] else 2) + + for item in self.data: + item['roomCountLevel'] = to_level(item['roomCount'], self.room_th) + item['branchLevel'] = to_level(item['highDegBranchCount'], self.branch_th) + + def __len__(self): + return len(self.data) + + # ------------------------------------------------------------------ + # 掩码辅助(与 GinkaVQDataset 相同逻辑) + # ------------------------------------------------------------------ + @staticmethod + def _sample_mask_ratio(min_r=0.05, max_r=1.0) -> float: + r = np.random.beta(2, 2) + return min_r + (max_r - min_r) * r + + @staticmethod + def _random_mask(h: int, w: int) -> np.ndarray: + ratio = GinkaStageDataset._sample_mask_ratio() + total = h * w + idx = np.random.choice(total, int(total * ratio), replace=False) + mask = np.zeros(total, dtype=bool) + mask[idx] = True + return mask + + @staticmethod + def _block_mask(h: int, w: int) -> np.ndarray: + ratio = GinkaStageDataset._sample_mask_ratio() + max_block = max(2, min(h, w) // 2) + target = int(h * w * ratio) + mask = np.zeros((h, w), dtype=bool) + while mask.sum() < target: + bh = np.random.randint(2, max_block + 1) + bw = np.random.randint(2, max_block + 1) + x = np.random.randint(0, max(1, h - bh + 1)) + y = np.random.randint(0, max(1, w - bw + 1)) + mask[x:x + bh, y:y + bw] = True + return mask.reshape(-1) + + def _std_mask(self, h: int, w: int) -> np.ndarray: + return self._random_mask(h, w) if random.random() < 0.5 else self._block_mask(h, w) + + # ------------------------------------------------------------------ + # 子集选择 + # ------------------------------------------------------------------ + def _choose_subset(self) -> str: + r = random.random() + if r < self.subset_cumw[0]: return 'A' + if r < self.subset_cumw[1]: return 'B' + if r < self.subset_cumw[2]: return 'C' + return 'D' + + # ------------------------------------------------------------------ + # 阶段一:结构骨架(floor + wall) + # ------------------------------------------------------------------ + def _make_stage1(self, raw_flat: np.ndarray, subset: str): + """ + 阶段一:预测 floor/wall,所有非 floor/wall tile 在目标中重映射为 floor。 + 子集决定向模型提供多少 wall 作为上下文条件。 + """ + H = W = 13 + + # 目标:非 floor/wall → floor + target = raw_flat.copy() + target[~np.isin(target, [self.FLOOR, self.WALL])] = self.FLOOR + + inp = target.copy() + + if subset == 'A': + # 标准随机 mask:随机遮盖部分 floor/wall + mask = self._std_mask(H, W) + inp[mask] = self.MASK_ID + + elif subset == 'B': + # 保留全部 wall,MASK floor + inp[inp == self.FLOOR] = self.MASK_ID + + elif subset == 'C': + # 随机保留部分 wall,MASK 其余(含全部 floor) + inp[inp == self.FLOOR] = self.MASK_ID + wall_idx = np.where(inp == self.WALL)[0] + if len(wall_idx) > 0: + ratio = random.random() * self.wall_mask_ratio + n = max(1, int(len(wall_idx) * ratio)) + chosen = np.random.choice(wall_idx, n, replace=False) + inp[chosen] = self.MASK_ID + + else: # D:与 B 相同(阶段一无 entrance 维度) + inp[inp == self.FLOOR] = self.MASK_ID + + loss_mask = (inp == self.MASK_ID) + return inp, target, loss_mask + + # ------------------------------------------------------------------ + # 阶段二:功能元素(door + monster + entrance) + # ------------------------------------------------------------------ + def _make_stage2(self, raw_flat: np.ndarray, subset: str): + """ + 阶段二:以 floor/wall 为上下文,预测 door/monster/entrance。 + resource 在输入与目标中均视为 floor(阶段二不负责资源)。 + 子集决定 wall 上下文的完整程度与 door/monster/entrance 的掩码方式。 + """ + # 目标:resource → floor + target = raw_flat.copy() + target[target == self.RESOURCE] = self.FLOOR + + # 基础输入:resource → floor,功能元素先保留,再按子集处理 + inp = raw_flat.copy() + inp[inp == self.RESOURCE] = self.FLOOR + + if subset == 'A': + # 随机遮盖部分 door/monster/entrance(部分上下文补全) + func_idx = np.where(np.isin(inp, [self.DOOR, self.MONSTER, self.ENTRANCE]))[0] + if len(func_idx) > 0: + ratio = random.random() * 0.8 + 0.2 # 20%~100% + n = max(1, int(len(func_idx) * ratio)) + chosen = np.random.choice(func_idx, n, replace=False) + inp[chosen] = self.MASK_ID + else: + # B/C/D:全部 door/monster/entrance → MASK + inp[np.isin(inp, [self.DOOR, self.MONSTER, self.ENTRANCE])] = self.MASK_ID + + if subset == 'C': + # 额外随机 mask 部分 wall(降低 wall 上下文质量) + wall_idx = np.where(inp == self.WALL)[0] + if len(wall_idx) > 0: + ratio = random.random() * self.wall_mask_ratio + n = max(1, int(len(wall_idx) * ratio)) + chosen = np.random.choice(wall_idx, n, replace=False) + inp[chosen] = self.MASK_ID + + # loss_mask:阶段二只对 door/monster/entrance 原始位置计算损失, + # 不对被额外 mask 的 wall 位置计算(它们在 target 中已知为 wall) + loss_mask = np.isin(raw_flat, [self.DOOR, self.MONSTER, self.ENTRANCE]) + return inp, target, loss_mask + + # ------------------------------------------------------------------ + # 阶段三:资源放置(resource) + # ------------------------------------------------------------------ + def _make_stage3(self, raw_flat: np.ndarray, subset: str): + """ + 阶段三:以完整骨架为上下文,预测 resource 位置。 + 所有 resource 位置在输入中替换为 MASK。 + 子集 A 随机保留部分 resource 作为上下文(部分补全训练), + 其余子集始终 MASK 全部 resource。 + """ + target = raw_flat.copy() + inp = raw_flat.copy() + + if subset == 'A': + # 随机遮盖部分 resource(部分上下文补全) + res_idx = np.where(inp == self.RESOURCE)[0] + if len(res_idx) > 0: + ratio = random.random() * 0.8 + 0.2 # 20%~100% + n = max(1, int(len(res_idx) * ratio)) + chosen = np.random.choice(res_idx, n, replace=False) + inp[chosen] = self.MASK_ID + else: + pass # 无 resource 时无需处理 + else: + # B/C/D:全部 resource → MASK + inp[inp == self.RESOURCE] = self.MASK_ID + + loss_mask = (inp == self.MASK_ID) + return inp, target, loss_mask + + # ------------------------------------------------------------------ + # __getitem__ + # ------------------------------------------------------------------ + def _augment(self, arr: np.ndarray) -> np.ndarray: + if np.random.rand() > 0.5: + k = np.random.randint(1, 4) + arr = np.rot90(arr, k).copy() + if np.random.rand() > 0.5: + arr = np.fliplr(arr).copy() + if np.random.rand() > 0.5: + arr = np.flipud(arr).copy() + return arr + + def __getitem__(self, idx): + item = self.data[idx] + + raw_np = self._augment(np.array(item['map'], dtype=np.int64)) # [H, W] + raw_flat = raw_np.reshape(-1) # [H*W] + subset = self._choose_subset() + + if self.stage == 1: + stage_input_np, target_np, loss_mask_np = self._make_stage1(raw_flat, subset) + elif self.stage == 2: + stage_input_np, target_np, loss_mask_np = self._make_stage2(raw_flat, subset) + else: + stage_input_np, target_np, loss_mask_np = self._make_stage3(raw_flat, subset) + + # 若 loss_mask 全为 False(如地图中无 resource 时的 stage3), + # 退回为全图损失,避免 NaN + if not loss_mask_np.any(): + loss_mask_np = np.ones_like(loss_mask_np) + + # VQ 切片:当前阶段编码器的输入(仅保留相关 tile) + raw_t = torch.LongTensor(raw_flat) + vq_keep = self._VQ_KEEP[self.stage] + if vq_keep is None: + vq_slice = raw_t.clone() + else: + vq_slice = make_slice(raw_t, vq_keep) + + # 结构标签 + sym_h, sym_v, sym_c = _compute_symmetry(raw_np) + cond_sym = sym_h * 4 + sym_v * 2 + sym_c + cond_room = item['roomCountLevel'] + cond_branch = item['branchLevel'] + cond_outer = item['outerWall'] + struct_cond = torch.LongTensor([cond_sym, cond_room, cond_branch, cond_outer]) + + return { + "raw_map": raw_t, + "vq_slice": vq_slice, + "stage_input": torch.LongTensor(stage_input_np), + "target_map": torch.LongTensor(target_np), + "loss_mask": torch.BoolTensor(loss_mask_np), + "subset": subset, + "struct_cond": struct_cond, + } + + if __name__ == "__main__": import os data_path = os.path.join(os.path.dirname(__file__), '..', 'data', 'ginka-dataset.json') diff --git a/ginka/maskGIT/model.py b/ginka/maskGIT/model.py index a0806d3..83c9aba 100644 --- a/ginka/maskGIT/model.py +++ b/ginka/maskGIT/model.py @@ -57,8 +57,12 @@ class GinkaMaskGIT(nn.Module): # z 投影:将 VQ 码字从 d_z 维映射到 d_model 维,供 cross-attention 使用 self.z_proj = nn.Sequential( - nn.Linear(d_z, d_model), - nn.LayerNorm(d_model), + nn.Linear(d_z, d_model * 2), + nn.LayerNorm(d_model * 2), + nn.GELU(), + + nn.Linear(d_model * 2, d_model), + nn.LayerNorm(d_model) ) # 结构标签嵌入(编码到 d_z 维度) @@ -69,8 +73,12 @@ class GinkaMaskGIT(nn.Module): self.outer_embed = nn.Embedding(OUTER_VOCAB, d_z) self.struct_proj = nn.Sequential( - nn.Linear(d_z, d_model), - nn.LayerNorm(d_model), + nn.Linear(d_z, d_model * 2), + nn.LayerNorm(d_model * 2), + nn.GELU(), + + nn.Linear(d_model * 2, d_model), + nn.LayerNorm(d_model) ) # Transformer:encoder 做 map token 自注意力,decoder 做与 z 的 cross-attention diff --git a/ginka/train_stage.py b/ginka/train_stage.py new file mode 100644 index 0000000..b8da7a5 --- /dev/null +++ b/ginka/train_stage.py @@ -0,0 +1,654 @@ +""" +三阶段级联训练脚本:各阶段独立训练,使用 GinkaStageDataset。 + +总损失 = L_CE(只对本阶段负责的 tile 位置计算)+ beta * L_commit + gamma * L_entropy + +各阶段分工: + stage=1 结构骨架:floor(0) + wall(1) + stage=2 功能元素:door(2) + monster(4) + entrance(5) + stage=3 资源放置:resource(3) + +用法示例: + python -m ginka.train_stage --stage 1 + python -m ginka.train_stage --stage 2 + python -m ginka.train_stage --stage 3 + python -m ginka.train_stage --stage 1 --resume True --state result/stage1/stage1-10.pth + python -m ginka.train_stage --stage 2 --pretrain_vq result/joint/joint-50.pth +""" + +import argparse +import math +import os +import sys +import random +from datetime import datetime + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +import torch.optim as optim +from tqdm import tqdm +from torch.utils.data import DataLoader + +from .vqvae.model import GinkaVQVAE +from .maskGIT.model import GinkaMaskGIT +from .dataset import GinkaStageDataset +from shared.image import matrix_to_image_cv + +# --------------------------------------------------------------------------- +# 各阶段配置 +# --------------------------------------------------------------------------- + +# 共用 VQ-VAE 超参 +VQ_L = 2 # 码字序列长度 +VQ_K = 8 # codebook 大小 +VQ_D_Z = 64 # 码字维度 +VQ_BETA = 0.5 # commit loss 权重 +VQ_GAMMA = 0.0 # entropy loss 权重 +VQ_LAYERS = 3 +VQ_DIM_FF = 512 +VQ_D_MODEL = 64 +VQ_NHEAD = 8 + +# 各阶段 MaskGIT 超参(按任务复杂度差异化配置) +STAGE_MG_CONFIGS = { + 1: dict(d_model=256, nhead=8, num_layers=6, dim_ff=2048), # 结构骨架,最重要 + 2: dict(d_model=192, nhead=8, num_layers=4, dim_ff=1024), # 功能元素 + 3: dict(d_model=128, nhead=8, num_layers=3, dim_ff=512), # 资源放置,最简单 +} + +# 各阶段监控的 tile 集合(用于分类别召回率统计) +STAGE_TILE_SETS = { + 1: {0: "floor", 1: "wall"}, + 2: {2: "door", 4: "monster", 5: "entrance"}, + 3: {3: "resource"}, +} + +# 各阶段损失权重(可单独调节 CE 与 VQ 损失的平衡) +# stage3 的 resource 极稀疏,大幅上调 ce_weight 以补偿类别不均衡 +STAGE_LOSS_CONFIG = { + 1: dict(ce_weight=1.0, vq_weight=1.0), # 结构骨架,标准权重 + 2: dict(ce_weight=1.5, vq_weight=0.5), # 功能元素较稀疏,上调 CE + 3: dict(ce_weight=3.0, vq_weight=0.5), # resource 极稀疏,显著上调 CE +} + +NUM_CLASSES = 7 +MASK_TOKEN = 6 +MAP_SIZE = 13 * 13 +MAP_H = MAP_W = 13 +FOCAL_GAMMA = 2.0 +GENERATE_STEP = 18 +BATCH_SIZE = 64 +WALL_MASK_RATIO = 0.8 + +MG_Z_DROPOUT = 0.1 +MG_STRUCT_DROPOUT = 0.1 + +SUBSET_WEIGHTS = (0.5, 0.2, 0.2, 0.1) + +device = torch.device( + "cuda:1" if torch.cuda.is_available() + else "mps" if torch.backends.mps.is_available() + else "cpu" +) + +disable_tqdm = not sys.stdout.isatty() + +# --------------------------------------------------------------------------- +# 参数解析 +# --------------------------------------------------------------------------- + +def _str2bool(v): + if isinstance(v, bool): return v + if v.lower() in ('true', '1', 'yes'): return True + if v.lower() in ('false', '0', 'no'): return False + raise argparse.ArgumentTypeError(f"布尔值应为 True/False,收到: {v!r}") + + +def parse_arguments(): + parser = argparse.ArgumentParser(description="三阶段级联训练") + parser.add_argument("--stage", type=int, required=True, choices=[1, 2, 3]) + parser.add_argument("--resume", type=_str2bool, default=False) + parser.add_argument( + "--state", type=str, default="", + help="续训时加载的检查点路径(自动推断 stage{N}/stage{N}-*.pth)", + ) + parser.add_argument("--train", type=str, default="ginka-dataset.json") + parser.add_argument("--validate", type=str, default="ginka-eval.json") + parser.add_argument("--epochs", type=int, default=100) + parser.add_argument("--checkpoint", type=int, default=5) + parser.add_argument("--load_optim", type=_str2bool, default=True) + parser.add_argument( + "--freeze_vq", type=_str2bool, default=False, + help="冻结 VQ 编码器,只训练 MaskGIT(适合加载预训练编码器后热身)", + ) + parser.add_argument( + "--pretrain_vq", type=str, default="", + help="从 train_vq.py 的联合训练检查点中导入对应通道的 VQ 编码器权重", + ) + return parser.parse_args() + +# --------------------------------------------------------------------------- +# Focal Loss(与 train_vq.py 一致) +# --------------------------------------------------------------------------- + +def focal_loss(logits, targets, gamma=FOCAL_GAMMA, reduction='none'): + ce = F.cross_entropy(logits, targets, reduction='none') + pt = torch.exp(-ce) + fl = (1.0 - pt) ** gamma * ce + if reduction == 'mean': return fl.mean() + if reduction == 'sum': return fl.sum() + return fl + + +def masked_focal_loss(logits, targets, loss_mask, gamma=FOCAL_GAMMA): + """ + 只对 loss_mask 为 True 的位置计算 focal loss 均值。 + + Args: + logits: [B, C, H*W] + targets: [B, H*W] + loss_mask: [B, H*W] bool + """ + per_token = focal_loss(logits, targets, gamma, reduction='none') # [B, H*W] + selected = per_token[loss_mask] + if selected.numel() == 0: + return per_token.mean() + return selected.mean() + +# --------------------------------------------------------------------------- +# MaskGIT 推理(cosine schedule) +# --------------------------------------------------------------------------- + +@torch.no_grad() +def maskgit_generate( + model_mg: GinkaMaskGIT, + z: torch.Tensor, + steps: int = GENERATE_STEP, + init_map: torch.Tensor = None, + struct_cond: torch.Tensor = None, +) -> torch.Tensor: + """ + 迭代生成地图(cosine schedule unmasking)。 + + Args: + init_map: 可选初始地图;非 MASK 位置在生成中保持不变。 + + Returns: + [B, MAP_SIZE] LongTensor + """ + B = z.shape[0] + map_seq = ( + torch.full((B, MAP_SIZE), MASK_TOKEN, device=device) + if init_map is None else init_map.clone().to(device) + ) + + generatable = (map_seq == MASK_TOKEN) + + for step in range(steps): + if not generatable.any(): + break + + logits = model_mg(map_seq, z, struct_cond=struct_cond) # [B, S, C] + probs = F.softmax(logits, dim=-1) + dist = torch.distributions.Categorical(probs) + sampled = dist.sample() + confidences = torch.gather(probs, -1, sampled.unsqueeze(-1)).squeeze(-1) + confidences = confidences.masked_fill(~generatable, float('inf')) + + ratio = math.cos(((step + 1) / steps) * math.pi / 2) + new_map = map_seq.clone() + + for b in range(B): + n_gen = int(generatable[b].sum().item()) + n_keep = int(ratio * n_gen) + if n_keep > 0: + _, keep_idx = torch.topk(confidences[b], k=n_keep, largest=False) + pred_b = sampled[b].clone() + pred_b[keep_idx] = MASK_TOKEN + new_map[b] = torch.where(generatable[b], pred_b, map_seq[b]) + else: + new_map[b] = torch.where(generatable[b], sampled[b], map_seq[b]) + + map_seq = new_map + + return map_seq + +# --------------------------------------------------------------------------- +# 可视化工具(与 train_vq.py 保持一致) +# --------------------------------------------------------------------------- + +def make_map_image(map_flat, tile_dict): + arr = map_flat.cpu().numpy().reshape(MAP_H, MAP_W) + return matrix_to_image_cv(arr, tile_dict) + + +def hstack_images(imgs, gap=4, color=(255, 255, 255)): + max_h = max(img.shape[0] for img in imgs) + + def _pad(img): + dh = max_h - img.shape[0] + return img if dh == 0 else np.concatenate( + [img, np.full((dh, img.shape[1], 3), color, dtype=np.uint8)], axis=0) + + vline = np.full((max_h, gap, 3), color, dtype=np.uint8) + result = _pad(imgs[0]) + for img in imgs[1:]: + result = np.concatenate([result, vline, _pad(img)], axis=1) + return result + + +def grid_images(imgs, gap=4, bg=(255, 255, 255)): + n = len(imgs) + if n == 0: return np.zeros((1, 1, 3), dtype=np.uint8) + if n == 1: return imgs[0] + mid = math.ceil(n / 2) + top = hstack_images(imgs[:mid], gap, bg) + bot_imgs = imgs[mid:] + if not bot_imgs: return top + bot = hstack_images(bot_imgs, gap, bg) + tw, bw = top.shape[1], bot.shape[1] + if tw > bw: + bot = np.concatenate( + [bot, np.full((bot.shape[0], tw - bw, 3), bg, dtype=np.uint8)], axis=1) + elif bw > tw: + top = np.concatenate( + [top, np.full((top.shape[0], bw - tw, 3), bg, dtype=np.uint8)], axis=1) + hline = np.full((gap, top.shape[1], 3), bg, dtype=np.uint8) + return np.concatenate([top, hline, bot], axis=0) + + +def label_image(img, text, font_scale=0.45): + bar = np.full((16, img.shape[1], 3), (40, 40, 40), dtype=np.uint8) + cv2.putText( + bar, text, (2, 13), cv2.FONT_HERSHEY_SIMPLEX, + font_scale, (200, 200, 200), 1, cv2.LINE_AA, + ) + return np.concatenate([bar, img], axis=0) + + +def make_random_struct_cond(): + from .maskGIT.model import SYM_VOCAB, ROOM_VOCAB, BRANCH_VOCAB, OUTER_VOCAB + return torch.tensor([[ + random.randint(0, SYM_VOCAB - 2), + random.randint(0, ROOM_VOCAB - 2), + random.randint(0, BRANCH_VOCAB - 2), + random.randint(0, OUTER_VOCAB - 2), + ]], dtype=torch.long, device=device) + +# --------------------------------------------------------------------------- +# 按阶段构造推理初始地图 +# --------------------------------------------------------------------------- + +def make_stage_init(stage: int, context_map: torch.Tensor) -> torch.Tensor: + """ + 根据阶段构造 MaskGIT 的推理初始地图。 + + Stage 1: 全 MASK(或保留稀疏 wall 种子) + Stage 2: 保留 floor/wall 上下文,其余 → MASK + Stage 3: 保留完整上下文(floor/wall/door/monster/entrance),resource → MASK + """ + init = context_map.clone() + + if stage == 1: + # 全 MASK(不依赖上下文地图) + init = torch.full_like(init, MASK_TOKEN) + + elif stage == 2: + # 保留 floor/wall,其余 → MASK + mask = ~torch.isin(init, torch.tensor([0, 1], device=init.device)) + init[mask] = MASK_TOKEN + + else: # stage == 3 + # 保留非 resource,resource → MASK + init[init == 3] = MASK_TOKEN + + return init + + +def make_random_wall_seed(ratio_min=0.02, ratio_max=0.08): + ratio = random.uniform(ratio_min, ratio_max) + n_wall = max(2, int(MAP_SIZE * ratio)) + seed = torch.full((1, MAP_SIZE), MASK_TOKEN, dtype=torch.long, device=device) + idx = torch.randperm(MAP_SIZE)[:n_wall] + seed[0, idx] = 1 + return seed + +# --------------------------------------------------------------------------- +# 验证函数 +# --------------------------------------------------------------------------- + +@torch.no_grad() +def validate( + stage: int, + enc: GinkaVQVAE, + model_mg: GinkaMaskGIT, + dataloader_val: DataLoader, + tile_dict: dict, + epoch: int, + n_rand: int = 3, +): + enc.eval() + model_mg.eval() + + epoch_dir = f"result/stage{stage}_img/e{epoch:04d}" + os.makedirs(epoch_dir, exist_ok=True) + + val_loss_total = 0.0 + val_steps = 0 + captured = {s: None for s in ('A', 'B', 'C', 'D')} + + for batch in tqdm(dataloader_val, desc="Validating", leave=False, disable=disable_tqdm): + raw_map = batch["raw_map"].to(device) + vq_slice = batch["vq_slice"].to(device) + stage_input = batch["stage_input"].to(device) + target_map = batch["target_map"].to(device) + loss_mask = batch["loss_mask"].to(device) + struct_cond = batch["struct_cond"].to(device) + subsets = batch["subset"] + + z_q, _, _, vq_loss, _, _ = enc(vq_slice) + logits = model_mg(stage_input, z_q, struct_cond=struct_cond) + + ce = masked_focal_loss(logits.permute(0, 2, 1), target_map, loss_mask) + val_loss_total += (ce + vq_loss).item() + val_steps += 1 + + for i in range(raw_map.shape[0]): + s = subsets[i] + if captured[s] is None: + captured[s] = { + "raw": raw_map[i:i+1].clone(), + "stage_input": stage_input[i:i+1].clone(), + "z_q": z_q[i:i+1].clone(), + "struct_cond": struct_cond[i:i+1].clone(), + } + + if all(v is not None for v in captured.values()): + break + + # ---- 可视化:每个子集一张图 ---------------------------------------- + for sub, cap in captured.items(): + if cap is None: + continue + + raw_img = label_image(make_map_image(cap["raw"][0], tile_dict), "GT") + inp_img = label_image(make_map_image(cap["stage_input"][0], tile_dict), f"stage{stage} input") + + # 真实 z 的迭代生成 + init = make_stage_init(stage, cap["stage_input"][0].unsqueeze(0)) + gen = maskgit_generate( + model_mg, cap["z_q"], + init_map=init, struct_cond=cap["struct_cond"], + ) + gen_img = label_image(make_map_image(gen[0], tile_dict), "z_real gen") + + # 随机 z 的生成 + rand_imgs = [] + for i in range(n_rand): + z_r = enc.sample(1, device) + sc_r = make_random_struct_cond() + init2 = make_stage_init(stage, cap["raw"][0].unsqueeze(0)) + gen_r = maskgit_generate(model_mg, z_r, init_map=init2, struct_cond=sc_r) + rand_imgs.append(label_image(make_map_image(gen_r[0], tile_dict), f"z_rand_{i+1}")) + + row = [raw_img, inp_img, gen_img] + rand_imgs + cv2.imwrite(f"{epoch_dir}/subset_{sub}.png", grid_images(row)) + + # ---- 场景:完全自主生成 ----------------------------------------------- + # stage1:从随机稀疏墙壁种子出发(完全不依赖 GT) + # stage2:以验证集中采样的 floor/wall 结构为上下文,随机 z₂(模拟级联推理) + # stage3:以验证集中采样的完整功能地图为上下文,随机 z₃(模拟级联推理) + context_pool = [cap["raw"][0] for cap in captured.values() if cap is not None] + + rand_free = [] + for i in range(n_rand + 1): + z_r = enc.sample(1, device) + sc_r = make_random_struct_cond() + + if stage == 1: + # 稀疏 wall 种子作为提示,模型自主补全 floor/wall + init = make_random_wall_seed() + else: + # 从验证集上下文池中轮流取一张图作为前序阶段的输出 + ctx = context_pool[i % len(context_pool)].unsqueeze(0) + # make_stage_init 会自动将本阶段负责的 tile 位置替换为 MASK + init = make_stage_init(stage, ctx) + + gen = maskgit_generate(model_mg, z_r, init_map=init, struct_cond=sc_r) + rand_free.append(label_image(make_map_image(gen[0], tile_dict), f"free_{i+1}")) + cv2.imwrite(f"{epoch_dir}/scene_free_random.png", grid_images(rand_free)) + + return val_loss_total / max(val_steps, 1) + +# --------------------------------------------------------------------------- +# 主训练函数 +# --------------------------------------------------------------------------- + +def train(): + print(f"Using device: {device}") + args = parse_arguments() + stage = args.stage + + result_dir = f"result/stage{stage}" + result_img_dir = f"result/stage{stage}_img" + os.makedirs(result_dir, exist_ok=True) + os.makedirs(result_img_dir, exist_ok=True) + + # ---- VQ 编码器(单路)---- + mg_cfg = STAGE_MG_CONFIGS[stage] + enc = GinkaVQVAE( + num_classes=NUM_CLASSES, + L=VQ_L, + K=VQ_K, + d_z=VQ_D_Z, + d_model=VQ_D_MODEL, + nhead=VQ_NHEAD, + num_layers=VQ_LAYERS, + dim_ff=VQ_DIM_FF, + map_size=MAP_SIZE, + beta=VQ_BETA, + gamma=VQ_GAMMA, + ).to(device) + + model_mg = GinkaMaskGIT( + num_classes=NUM_CLASSES, + d_model=mg_cfg["d_model"], + d_z=VQ_D_Z, + dim_ff=mg_cfg["dim_ff"], + nhead=mg_cfg["nhead"], + num_layers=mg_cfg["num_layers"], + map_size=MAP_SIZE, + z_dropout=MG_Z_DROPOUT, + struct_dropout=MG_STRUCT_DROPOUT, + ).to(device) + + enc_params = sum(p.numel() for p in enc.parameters()) + mg_params = sum(p.numel() for p in model_mg.parameters()) + print(f"[Stage {stage}] VQ Encoder 参数量: {enc_params:,} ({enc_params/1e6:.3f}M)") + print(f"[Stage {stage}] MaskGIT 参数量: {mg_params:,} ({mg_params/1e6:.3f}M)") + + # ---- 数据集 ---- + dataset_train = GinkaStageDataset( + args.train, + stage=stage, + subset_weights=SUBSET_WEIGHTS, + wall_mask_ratio=WALL_MASK_RATIO, + ) + dataset_val = GinkaStageDataset( + args.validate, + stage=stage, + subset_weights=SUBSET_WEIGHTS, + room_thresholds=dataset_train.room_th, + branch_thresholds=dataset_train.branch_th, + wall_mask_ratio=WALL_MASK_RATIO, + ) + dataloader_train = DataLoader( + dataset_train, + batch_size=BATCH_SIZE, + shuffle=True, + num_workers=0, + pin_memory=(device.type == "cuda"), + ) + dataloader_val = DataLoader( + dataset_val, + batch_size=8, + shuffle=True, + num_workers=0, + ) + + # ---- 优化器 ---- + all_params = list(enc.parameters()) + list(model_mg.parameters()) + optimizer = optim.AdamW(all_params, lr=2e-4, weight_decay=1e-2) + scheduler = optim.lr_scheduler.CosineAnnealingLR( + optimizer, T_max=args.epochs, eta_min=1e-6, + ) + + # ---- 权重加载 ---- + start_epoch = 0 + + if args.pretrain_vq: + # 从 train_vq.py 的联合训练检查点加载对应通道的 VQ 编码器 + ckpt = torch.load(args.pretrain_vq, map_location=device) + enc_key = f"enc{stage}" + if enc_key in ckpt: + enc.load_state_dict(ckpt[enc_key], strict=False) + print(f"已从 {args.pretrain_vq} 加载 {enc_key} 权重。") + else: + print(f"警告:检查点中未找到 {enc_key},跳过权重加载。") + + if args.resume: + state_path = args.state or f"{result_dir}/stage{stage}-latest.pth" + ckpt = torch.load(state_path, map_location=device) + enc.load_state_dict(ckpt["enc"], strict=False) + model_mg.load_state_dict(ckpt["mg_state"], strict=False) + if args.load_optim and ckpt.get("optim_state") is not None: + optimizer.load_state_dict(ckpt["optim_state"]) + start_epoch = ckpt.get("epoch", 0) + print(f"从 epoch {start_epoch} 接续训练。") + + # ---- tile 贴图 ---- + tile_dict = {} + for f in os.listdir("tiles"): + name = os.path.splitext(f)[0] + img = cv2.imread(f"tiles/{f}", cv2.IMREAD_UNCHANGED) + if img is not None: + tile_dict[name] = img + + # ---- 冻结 VQ 编码器(可选)---- + if args.freeze_vq: + for p in enc.parameters(): + p.requires_grad_(False) + print(f"[Stage {stage}] VQ 编码器已冻结。") + + # ---- 训练循环 ---- + for epoch in tqdm( + range(start_epoch, start_epoch + args.epochs), + desc=f"Stage{stage} Training", + disable=disable_tqdm, + ): + enc.train() + model_mg.train() + + loss_total = 0.0 + ce_total = 0.0 + vq_loss_total = 0.0 + subset_stats = {'A': 0, 'B': 0, 'C': 0, 'D': 0} + + # 按 tile 统计召回率(用于监控各类 tile 的预测准确性) + tile_correct = {tid: 0 for tid in STAGE_TILE_SETS[stage]} + tile_total = {tid: 0 for tid in STAGE_TILE_SETS[stage]} + + for batch in tqdm( + dataloader_train, + leave=False, + desc="Epoch Progress", + disable=disable_tqdm, + ): + raw_map = batch["raw_map"].to(device) + vq_slice = batch["vq_slice"].to(device) + stage_input = batch["stage_input"].to(device) + target_map = batch["target_map"].to(device) + loss_mask = batch["loss_mask"].to(device) + struct_cond = batch["struct_cond"].to(device) + + for s in batch["subset"]: + subset_stats[s] = subset_stats.get(s, 0) + 1 + + # ---- 前向传播 ---- + z_q, _, _, vq_loss, commit_loss, entropy_loss = enc(vq_slice) + logits = model_mg(stage_input, z_q, struct_cond=struct_cond) # [B, S, C] + + # ---- 仅对本阶段 tile 位置计算 focal loss ---- + ce_loss = masked_focal_loss(logits.permute(0, 2, 1), target_map, loss_mask) + loss_cfg = STAGE_LOSS_CONFIG[stage] + loss = loss_cfg["ce_weight"] * ce_loss + loss_cfg["vq_weight"] * vq_loss + + optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(all_params, max_norm=1.0) + optimizer.step() + + loss_total += loss.detach().item() + ce_total += ce_loss.detach().item() + vq_loss_total += vq_loss.detach().item() + + # ---- 分 tile 召回率统计 ---- + with torch.no_grad(): + preds = logits.argmax(dim=-1) # [B, S] + for tid in STAGE_TILE_SETS[stage]: + gt_mask = (target_map == tid) & loss_mask + tile_total[tid] += gt_mask.sum().item() + tile_correct[tid] += (preds[gt_mask] == tid).sum().item() + + scheduler.step() + + n = len(dataloader_train) + recall_str = " ".join( + f"{STAGE_TILE_SETS[stage][tid]}={tile_correct[tid]/(tile_total[tid]+1e-6):.2%}" + for tid in STAGE_TILE_SETS[stage] + ) + tqdm.write( + f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " + f"Epoch {epoch + 1:4d} | " + f"Loss {loss_total/n:.5f} " + f"Focal {ce_total/n:.5f} " + f"VQ {vq_loss_total/n:.5f} | " + f"Recall: {recall_str} | " + f"LR {scheduler.get_last_lr()[0]:.6f} | " + f"Subsets {subset_stats}" + ) + + # ---- 检查点 + 验证 ---- + if (epoch + 1) % args.checkpoint == 0: + ckpt_path = f"{result_dir}/stage{stage}-{epoch + 1}.pth" + torch.save({ + "epoch": epoch + 1, + "stage": stage, + "enc": enc.state_dict(), + "mg_state": model_mg.state_dict(), + "optim_state": optimizer.state_dict(), + }, ckpt_path) + tqdm.write(f" 检查点已保存: {ckpt_path}") + + val_loss = validate(stage, enc, model_mg, dataloader_val, tile_dict, epoch + 1) + tqdm.write(f"[Validate] Epoch {epoch + 1:4d} | Val Loss {val_loss:.5f}") + + enc.train() + model_mg.train() + + # ---- 最终存档 ---- + torch.save({ + "epoch": start_epoch + args.epochs, + "stage": stage, + "enc": enc.state_dict(), + "mg_state": model_mg.state_dict(), + }, f"{result_dir}/stage{stage}_final.pth") + print(f"[Stage {stage}] 训练结束。") + + +# --------------------------------------------------------------------------- +if __name__ == "__main__": + torch.set_num_threads(4) + train()