ginka-generator/docs/z-improvement-design.md
unanmed a69403d6bf feat: vq 预训练
Co-authored-by: Copilot <copilot@github.com>
2026-04-28 16:31:53 +08:00

296 lines
17 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# z 控制力改进方案设计文档
## 背景
当前 VQ-VAE + MaskGIT 联合训练方案中VQ-VAE 编码器将完整地图压缩为离散隐变量 zz 经 Cross-Attention 注入 MaskGIT。实践中发现 z 对生成结果的控制力仍然不足——即使给定相同的 z生成结果与原地图的结构相似度也偏低。
本文档整理以下两个可行改进方向,并详细分析各自的技术路径与取舍:
| 方案 | 核心思路 | 状态 |
| ------ | -------------------------------------------------- | -------- |
| 方案 A | 重建一致性约束:将生成结果回送编码器,令 z 闭环 | 已实施 |
| 方案 B | 多路分拆编码:将地图按层次结构分拆为多部分分别编码 | 待细化 |
| 方案 C | 多阶段生成:先墙壁,再门怪,最后资源 | 后续计划 |
| 方案 D | VQ 编码器预训练:先单独训练编码器学会重建,再联合 | 待细化 |
---
## 方案 A重建一致性约束z 闭环)
### 核心思路
在联合训练过程中,对 MaskGIT 生成的地图再次走一遍编码器,得到 z_pred然后约束
$$\mathcal{L}_{consist} = \| z - z_{pred} \|^2$$
如果 z_pred 能与原始 z 对齐,说明模型学到了"z 控制 → 生成结果 → 编码 → 同一个 z"的完整回路,从而强化 z 对生成内容的控制力。
### 核心难点:离散地图的梯度传递
MaskGIT 输出的是每个位置的分类 logits最终生成的地图是通过 argmax或 Gumbel 采样)得到的**离散整数矩阵**。离散采样不可微,梯度无法直接通过 argmax 回传到 MaskGIT。
以下是三种可行的梯度传递方案:
#### 方案 A-1软分布近似Soft Logits 送入编码器)
跳过 argmax将 MaskGIT 输出的原始 logits或 softmax 概率分布)直接送入编码器,替代离散 token 输入。
```
MaskGIT logits [B, H*W, V]
将每个位置的 softmax(logits) 与 tile embedding 加权求和
→ 得到 [B, H*W, d_emb] 的软嵌入序列
VQ-VAE 编码器(接收软嵌入输入)→ z_pred
L_consist = ||z - z_pred||²
梯度正常反传到 MaskGIT无离散断层
```
**优点**:完全可微,无需额外技巧。
**缺点**:编码器在训练时接收的是连续软嵌入,与推理时接收离散地图存在 train/inference gap编码器需要支持两种输入形式增加复杂度。
#### 方案 A-2Straight-Through EstimatorSTE
利用 STE 在前向使用 argmax在后向将梯度"直通"绕过 argmax 传给 logits
$$\text{forward: } \hat{y} = \text{onehot}(\arg\max(\text{logits}))$$
$$\text{backward: } \frac{\partial \mathcal{L}}{\partial \text{logits}} \approx \frac{\partial \mathcal{L}}{\partial \text{softmax(logits)}}$$
即在前向正常采样离散地图,在后向把 z_pred 的梯度传回 softmax(logits),再流入 MaskGIT。
**优点**:编码器输入与推理时一致(离散地图),无 train/inference gapVQ-VAE 自身已使用 STE风格统一。
**缺点**STE 是梯度近似在离散度很高logits 尖锐)时偏差较大;效果依赖于 logits 的平滑程度。
#### 方案 A-3REINFORCE / 策略梯度(作为备选)
将 z_consist 损失视为强化学习奖励:
$$\mathcal{L}_{RL} = -\mathbb{E}_{\hat{y} \sim \pi_\theta}[r(\hat{y})]$$
其中 $r(\hat{y}) = -\| z - \text{Enc}(\hat{y}) \|^2$(负的一致性误差作为奖励),$\pi_\theta$ 为 MaskGIT 的采样策略。
**优点**:不依赖梯度近似,理论上完全正确。
**缺点**:方差高,训练不稳定,需要引入 baseline 估计地图较大时169 个位置的联合分布)方差尤为严重。**不推荐作为主方案。**
### 推荐路径
**优先尝试方案 A-1软分布近似**,但输入编码器的不应是原始 softmax 概率——logits 的分布与真实离散地图差距较大,编码器难以理解。具体做法:对 logits 先施加温度 smoothing降低分布的尖锐程度再将平滑后的 softmax 概率与 tile embedding 矩阵做加权和,得到连续软嵌入序列后送入编码器。
若实验中发现训练与推理的 gap 仍然明显影响效果,再切换到方案 A-2STE
### 损失权重
$\mathcal{L}_{consist}$ 作为辅助损失,建议权重 $\lambda_{consist}$ 初始设为 0.1,观察 z 的编码熵codebook 使用分布)变化后再调整:
$$\mathcal{L}_{total} = \mathcal{L}_{CE} + \beta \cdot \mathcal{L}_{commit} + \gamma \cdot \mathcal{L}_{uniform} + \lambda \cdot \mathcal{L}_{consist}$$
### 训练细节
- 一致性约束在**所有子集A/B/C/D**中均应用。虽然子集 B/C/D 的输入地图不完整,但模型同样需要参考 z 来生成,对所有子集施加一致性约束有助于提高模型在指定条件下的生成能力。
- 训练时 MaskGIT 只进行单步解码,一致性约束在该单步解码结果上计算,无需多步展开,计算开销可控。
---
## 方案 B多路分拆编码Hierarchical z
### 核心思路
将完整地图按语义层次拆分为多个子地图,分别编码为较短的 z 向量,拼接后注入 MaskGIT。不同的 z 通道承载不同层次的语义信息,使 MaskGIT 能够从多个粒度获得控制信号。
### 分拆方案
| 通道 | 保留的图块类型 | 语义含义 |
| ------ | --------------------------------------------- | ------------------ |
| 通道 1 | 墙壁wall+ 楼梯/传送点stair | 地图骨架与层间连接 |
| 通道 2 | 墙壁 + 入口entry+ 怪物mob+ 门door | 关卡结构与交互元素 |
| 通道 3 | 完整地图(所有图块) | 全局风格与密度 |
未保留的图块位置填为"空白"tile = 0后送入对应编码器。
### 架构设计
```
完整地图 [H, W]
├──► 提取通道 1墙 + 楼梯)──► Encoder_1 ──► z_1 [L_1, d_z]
├──► 提取通道 2墙 + 入口 + 怪 + 门)──► Encoder_2 ──► z_2 [L_2, d_z]
└──► 原始地图 ──► Encoder_3 ──► z_3 [L_3, d_z]
z = Concat([z_1, z_2, z_3], dim=1) # [L_1 + L_2 + L_3, d_z]
MaskGIT Cross-Attentionz 作为 memory
```
### 码字数量分配
三个通道的信息量递增码字数量z 序列长度)应与信息量成比例。以总码字数 L 为预算:
| 通道 | 建议码字比例 | 说明 |
| ---- | ------------ | -------------------------------------- |
| z_1 | L/4 | 骨架信息精简,少量码字即可捕捉 |
| z_2 | L/4 | 关卡结构比骨架复杂,但仍是局部稀疏信息 |
| z_3 | L/2 | 全图风格需要更多表达维度 |
若当前方案中 L = 4即每通道 1 个码字),则分配为 1 + 1 + 2 = 4与原方案参数量相同。考虑到数据集规模较小约 10k 条),总码字数 L 不宜设置过大,避免 codebook 因样本不足而难以充分训练。
### 是否共享编码器权重
| 策略 | 优点 | 缺点 |
| ---------------------------- | ---------------------- | ------------------------------------ |
| 三路独立编码器 | 每个通道学习专属的表示 | 参数量增加约 3x |
| 共享底层权重Siamese 风格) | 参数高效,底层特征共用 | 需要精心设计通道标识符区分三路输入 |
| 完全共享(同一个编码器) | 参数最少 | 三路输入差异大,同一编码器可能欠拟合 |
**推荐****共享主干 Transformer + 三路独立输入头与输出头**。即三个通道共享所有 Transformer 层的权重但各自拥有独立的输入嵌入层input head将图块 id 映射为特征向量和独立的输出投影层output head量化前的线性投影。这样在参数量基本不变的情况下让不同通道在输入特征空间和量化前的表示上均有分化空间无需额外引入通道标识符 token。
### 推理时的使用方式
推理时三个通道均可独立随机采样:
- **完全随机生成**:三路 z 均从 codebook 随机采样;
- **骨架条件生成**:通道 1 的 z 由用户手绘的墙壁地图编码得到,通道 2、3 随机采样;
- **精确条件生成**:通道 3 的 z 由参考地图编码得到,通道 1、2 随机采样(风格迁移场景)。
这一机制使三个通道的解耦具有实际的用户交互价值,而不仅仅是训练侧的正则化手段。
### 与方案 A 的兼容性
方案 B 与方案 A 不冲突。在方案 B 的架构下,重建一致性约束可以单独作用于**通道 3全图编码器**,通道 1 和 2 的编码器由于输入为稀疏子地图,一致性约束意义相对较小。
---
---
## 方案 DVQ 编码器预训练
### 问题诊断
当前联合训练时VQ 编码器和 MaskGIT 从随机初始化开始同步优化。由于编码器尚未学到任何地图语义,早期 z 基本是随机噪声MaskGIT 无法从中获得有效的条件信号,两者的优化信号相互干扰,容易导致训练早期陷入局部最优或收敛缓慢。
解决思路:在联合训练开始前,先单独预训练 VQ 编码器,使其具备初步的地图语义理解能力,再以此为初始化启动联合训练。
### 核心思路
为 VQ-VAE 临时增加一个轻量解码头Decoder Head构成完整的自编码器以完整地图重建为目标进行预训练
$$\mathcal{L}_{pretrain} = \mathcal{L}_{CE}^{recon} + \beta \cdot \mathcal{L}_{commit} + \gamma \cdot \mathcal{L}_{uniform}$$
其中 $\mathcal{L}_{CE}^{recon}$ 是对全部 169 个位置的交叉熵重建损失(不做掩码,全图重建)。预训练完成后,解码头被丢弃,编码器权重作为联合训练的初始化。
### 解码头设计
解码头的职责是将 z_q [B, L, d_z] 还原为 [B, H*W, num_classes],有以下两种设计选项:
#### 选项 D-1Cross-Attention 解码头(推荐)
```
z_q [B, L, d_z]
可学习位置查询 [B, H*W, d_z](每个格子对应一个 query
│ Cross-Attentionquery=位置查询key/value=z_q
线性分类头 → logits [B, H*W, num_classes]
```
与 MaskGIT 的 Cross-Attention 结构高度一致,预训练阶段即可验证"z → 地图"的解码路径是否畅通。解码头参数量小(单层 Cross-Attention + Linear预训练速度快。
#### 选项 D-2简单线性展开基线
```
z_q [B, L, d_z]
│ Flatten → Linear
logits [B, H*W, num_classes]
```
实现最简单,但 L × d_z → H\*W × num_classes 的映射会引入大量参数L=32, d_z=128 时约 512K且缺乏空间归纳偏置效果可能较差。
**推荐选项 D-1**,结构与联合训练阶段的 MaskGIT 解码路径一致,预训练阶段已对"z 作为 Cross-Attention memory 驱动生成"这一机制进行充分热身。
### 训练策略
| 阶段 | 模型状态 | 目标 | 建议轮数 |
| -------------------- | ----------------------------- | ----------------------------------------- | ------------ |
| 阶段 0预训练 | 编码器 + 临时解码头 | 全图重建,$\mathcal{L}_{pretrain}$ 收敛 | 2050 epoch |
| 阶段 1联合热身 | 编码器冻结 + MaskGIT 训练 | 让 MaskGIT 先适应固定的 z 分布 | 2040 epoch |
| 阶段 2完整联合训练 | 全部参数解冻,编码器用较小 LR | 端到端联合优化(可叠加方案 A 一致性约束) | 正常训练轮数 |
> 阶段 1 的编码器冻结热身建议执行若直接解冻联合训练MaskGIT 早期的不稳定梯度可能逐渐覆盖编码器预训练获得的语义。考虑到 MaskGIT 收敛速度相对较慢,热身阶段建议适当延长至 2040 epoch。
### 实现要点
1. **解码头独立模块**:将解码头实现为独立的类(如 `VQDecodeHead`),不修改 `GinkaVQVAE` 的核心结构,预训练结束后直接丢弃,不影响联合训练代码路径。
2. **预训练脚本独立**:新增 `ginka/train_pretrain.py`,与联合训练脚本 `train_vq.py` 分离,便于单独调试。
3. **权重迁移**:预训练结束后通过 `model_vq.load_state_dict(ckpt['vq_state'], strict=False)` 将编码器权重加载到联合训练中。
4. **重建质量指标**:预训练阶段重点监控逐类别准确率(尤其是墙壁 tile=1 的召回率确认编码器已学到基本的空间结构语义。需注意codebook 容量远小于训练集数量,预训练的目标更倾向于让编码器学会地图的大致分类,而非像素级完整重建——重建损失在此主要作为分类学习的约束信号。
### 与其他方案的关系
- 方案 D 是**独立于方案 A/B 的训练流程优化**,不修改模型推理时的计算图,与方案 A 的一致性约束、方案 B 的多路编码均可叠加使用。
- 方案 D 完成后,方案 A 的一致性约束的初始条件更好(编码器已具有语义),收敛应更快、更稳定。
- 若最终采用方案 B多路分拆每个通道的编码器均可独立预训练后再联合训练。
---
## 两方案的对比
| 维度 | 方案 Az 闭环) | 方案 B多路分拆 |
| ---------------- | ---------------------------------- | ----------------------------------------- |
| 核心增益 | 强化现有 z 的控制力 | 增加 z 的语义分层,提升可解释性与可控性 |
| 实现复杂度 | 中等(需处理离散梯度问题) | 较高(需修改编码器输入管线 + 通道标识符) |
| 参数量变化 | 不变(同一编码器复用) | 小幅增加(独立头部 + 通道嵌入) |
| 训练稳定性 | 引入额外损失项,可能需要调权重 | 结构变化较大,需要较长的调试周期 |
| 推理灵活性 | 不变(仍为单路随机采样) | 提升(三路独立采样 / 条件指定) |
| 与现有方案兼容性 | 高(仅增加损失项,不改变模型结构) | 中等(需修改编码器 + dataset pipeline |
---
## 实施建议
### 阶段零:预训练编码器(方案 D可选但推荐
1. 实现 `VQDecodeHead`Cross-Attention 解码头)和独立预训练脚本 `ginka/train_pretrain.py`
2. 以全图重建为目标预训练 VQ 编码器 2050 epoch直至重建准确率尤其是墙壁类趋于稳定
3. 保存编码器权重,作为阶段一联合训练的初始化。
### 阶段一:验证方案 A低风险快速验证
1. 在现有联合训练代码中,对子集 A 的训练步骤增加软分布近似一致性损失;
2. 监控 `L_consist`、`codebook 使用熵`、`生成多样性pairwise distance` 三个指标;
3. 若 z 控制力提升明显(生成结果与参考地图结构相似度上升),方案 A 单独使用即可;
4. 若提升有限,再推进方案 B。
### 阶段二:实施方案 B较高收益中等成本
1. 修改 `data/src/auto.ts`,在序列化时输出三个通道的地图矩阵;
2. 修改 `ginka/dataset.py`,为每个样本加载三个通道的输入;
3. 修改编码器(`ginka/vqvae/model.py`),增加通道标识符 token
4. 修改 MaskGIT 的 cross-attention memory 拼接逻辑;
5. 复用方案 A 的一致性约束,仅作用于通道 3。
### 阶段三:多阶段生成(方案 C后续计划
先生成墙与入口的骨架,再生成门与怪物,最后生成资源与道具。此方案依赖方案 B 中通道 1 的编码器作为阶段间信息传递的载体,适合在方案 B 验证稳定后再推进。
---
## 待细化事项
- [x] 方案 A一致性损失的权重 $\lambda$ 如何随训练进度调度?→ 先使用常量(初始值 0.1),效果不佳再引入调度策略。
- [x] 方案 D预训练阶段是否对所有子集数据都进行预训练还是只用完整地图→ 仅使用完整地图raw_map。子集划分的差异体现在输入条件上但输出目标始终是完整地图预训练阶段无需区分子集。
- [x] 方案 D预训练完成后联合训练时编码器是否需要冻结热身阶段→ 建议执行冻结热身。若直接解冻联合训练MaskGIT 的不稳定梯度可能逐渐覆盖编码器预训练所获得的语义;考虑到 MaskGIT 收敛较慢,热身 epoch 数适当增大(建议 2040 epoch
- [x] 方案 A单步解码还是多步解码后计算一致性损失→ 训练时 MaskGIT 只进行单步解码,直接在单步结果上计算,无需多步展开。
- [x] 方案 B通道 2 的"墙壁"是否需要保留,还是只保留入口 + 怪 + 门?→ 保留墙壁。去掉墙壁后剩余内容趋向于散点,缺乏空间结构指导意义。
- [x] 方案 B三路 z 拼接后总长度是否超出 MaskGIT cross-attention 的合理 memory 长度?→ 先直接拼接,如有性能问题再评估截断或压缩策略。
- [ ] 方案 B通道 1/2 的编码器在用户未提供条件时如何优雅地回退到随机采样dropout 机制设计)?→ 暂不考虑分通道条件指定,等模型结构稳定后再设计此机制。