ginka-generator/docs/vqvae-maskgit-design.md
unanmed abbad781ab feat: 添加少数结构性标签
Co-authored-by: Copilot <copilot@github.com>
2026-04-27 14:56:21 +08:00

398 lines
17 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Ginka 地图生成模型重构设计文档
## 背景与动机
### 现状与问题
当前模型采用**热力图Heatmap**作为生成条件,引导 MaskGIT 模型生成地图。实践中发现以下问题:
1. **模型侧**MaskGIT 无法有效利用热力图信息,热力图对生成结果的控制力弱;
2. **用户侧**:用户难以手动构造语义合理的热力图,交互成本高;
3. **表达侧**:热力图是高维连续信号,携带冗余信息,但实际指导生成的有效信息量有限。
### 改进目标
- 保留 MaskGIT 的地图生成范式,但将生成条件改为**低维离散隐变量 z**
- z 由 **VQ-VAE 风格的编码器**从真实地图中抽取,形成一个可复用的 **codebook**
- 训练期间通过 codebook 为 MaskGIT 提供多样性控制信号;
- 推理期间**直接随机采样 z**,无需用户提供任何额外输入。
---
## 整体架构
系统由两个子模型组成,训练可以分阶段进行,也可以联合训练。
```
┌─────────────────────────────────────────────────────────┐
│ 训练阶段 │
│ │
│ 真实地图 ──► [VQ-VAE 编码器] ──► z离散码序列
│ │ │
│ ▼ │
│ [Codebook] │
│ │ │
│ ▼ │
│ 掩码地图 + z ──► [MaskGIT] ──► 预测被遮盖的图块 │
└─────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────┐
│ 推理阶段 │
│ │
│ 随机采样 z ──► [MaskGIT 迭代解码] ──► 生成地图 │
└─────────────────────────────────────────────────────────┘
```
---
## 子模型一VQ-VAE 风格编码器
**规模约束**:编码器参数量应控制在 **< 1M**。此前 MaskGIT 在热力图方案中已验证 4M 模型处于欠拟合与过拟合的边界改用新方案整体规模不会有大幅变化VQ-VAE 编码器作为辅助模块无需过大
### 职责
将一张完整的地图`[H, W]` 整数矩阵编码为一个**离散的 z**z 是从 codebook 中查得的嵌入向量或向量序列)。
### 设计方案
编码器统一采用** Transformer 结构**地图尺寸13×13不大Transformer 不会带来过大计算开销同时更擅长捕捉图块之间的长程依赖优于 CNN 方案
#### 方案 A全图单一码字z 为标量 index
- 编码器将整张地图压缩为一个 `d_z` 维向量再量化为 codebook 中的单个码字
- z 是一个整数 index对应 codebook 中一行嵌入
- 结构最简单但表达能力有限难以捕捉地图的局部多样性
```
地图 [H, W] ──► Transformer 编码 ──► [d_z] ──► 量化 ──► index标量
codebook[index] → z [d_z]
```
#### 方案 B空间码字序列z 为序列,**已选定**
- 编码器将地图编码为 `[L, d_z]` 的特征序列L 为码字数量
- 每个位置独立量化为 codebook 中的一个码字
- z `L` index 的组合表达能力更强
- 推理时随机采样 L index 即可
```
地图 [H, W] ──► Transformer 编码 ──► [L, d_z] ──► 逐位量化 ──► [L 个 index]
codebook[index_0..L] → z [L, d_z]
```
### VQ-VAE 量化机制
标准向量量化Vector Quantization
$$z_q = \arg\min_{e_k \in \mathcal{E}} \| z_e - e_k \|_2$$
其中 $\mathcal{E} = \{e_1, ..., e_K\}$ codebook$K$ 为码本大小
**直通估计Straight-Through Estimator**用于反向传播
$$\frac{\partial \mathcal{L}}{\partial z_e} \approx \frac{\partial \mathcal{L}}{\partial z_q}$$
### 损失函数设计
总损失由三部分组成
$$\mathcal{L}_{VQVAE} = \mathcal{L}_{recon} + \beta \cdot \mathcal{L}_{commit} + \gamma \cdot \mathcal{L}_{uniform}$$
| 损失项 | 说明 |
| ----------------------- | ----------------------------------------------------------- |
| $\mathcal{L}_{recon}$ | 重建损失确保 z 能恢复地图信息可选视是否加解码器而定 |
| $\mathcal{L}_{commit}$ | 承诺损失 $\|z_e - \text{sg}(z_q)\|^2$拉近编码向量与码字 |
| $\mathcal{L}_{uniform}$ | **均匀分布正则化**鼓励各码字被均等使用见下节 |
#### 均匀分布正则化
目标让所有 `K` 个码字在训练中被均等使用避免 codebook collapse
方案一**熵最大化**
$$\mathcal{L}_{uniform} = -H(p) = \sum_{k=1}^{K} p_k \log p_k$$
其中 $p_k$ 为码字 $k$ 在当前 batch 中被选中的频率使用 EMA 估计)。
方案二**EMA 更新 + 重置**
- 使用指数移动平均EMA更新 codebook不通过梯度更新
- 定期检测使用率低于阈值的码字将其重置为当前 batch 中随机样本点
方案三**codebook 使用 KL 散度**
$$\mathcal{L}_{uniform} = \text{KL}(p \| \mathcal{U}_{K})$$
其中 $\mathcal{U}_K$ 为均匀分布即期望每个码字被选中概率为 $1/K$。
**初始方案**采用**熵最大化**方案一通过梯度直接优化实现简单若后续出现 codebook collapse 问题再引入 EMA 更新 + 重置机制
### 是否需要解码器
| | 有解码器 | 无解码器 |
| ---------- | -------------------------- | ----------------------------------------- |
| 重建损失 | z 需要还原地图 | z 的质量由 MaskGIT 的生成质量间接约束 |
| 训练复杂度 | 较高 | 较低 |
| z 语义性 | z 明确包含地图结构信息 | z 更像风格/多样性控制 |
| 推荐场景 | 需要 z 携带结构语义 | 仅用于多样性控制 |
**已确认:不加解码器**直接与 MaskGIT 联合训练z 的语义由 MaskGIT 端的生成损失反向传播决定z 定位为风格与多样性控制信号而非结构重建指导—— z 能完整重建地图模型会忽略其他条件直接依赖 z反而降低可控性与泛化能力
相应地损失函数简化为
$$\mathcal{L} = \mathcal{L}_{CE}(\text{MaskGIT}) + \beta \cdot \mathcal{L}_{commit} + \gamma \cdot \mathcal{L}_{uniform}$$
---
## 子模型二MaskGIT 改造
### 现有输入
```python
def forward(self, map: torch.Tensor, heatmap: torch.Tensor):
# map: [B, H * W]
# heatmap: [B, C, H, W]
```
### 改造后输入
```python
def forward(self, map: torch.Tensor, z: torch.Tensor):
# map: [B, H * W] — 掩码后的地图 token 序列
# z: [B, L, d_z] — 离散隐变量方案B或 [B, d_z]方案A
```
移除原有的 `GinkaMaskGITCond`热力图 CNN 编码器 `gate_encoder` z 替代其功能
### z 注入 MaskGIT 的方式(待选)
#### 方式一Cross-Attention推荐
Transformer 解码器天然支持 cross-attention
```
map token sequence ──► self-attention ──► cross-attention ◄── z
预测 logits
```
z 作为 memory 输入到 `TransformerDecoder`map token 序列作为 query
优点z 可以是任意长度的序列方案B友好表达力强。**已选定此方案。**
#### 方式二z 拼接到序列头部Prefix Token
z 的每个码字嵌入作为前缀 token拼接到 map token 序列前
```
[z_0, z_1, ..., z_L, tile_0, tile_1, ..., tile_{H*W}]
```
输入到 TransformerEncoder 做自注意力位置编码需相应扩展
优点实现简单无需 encoder-decoder 分离
#### 方式三Add/FiLM 调制
z 投影后 FiLMFeature-wise Linear Modulation调制每一层的特征
$$y = \gamma(z) \odot x + \beta(z)$$
优点参数高效缺点z 必须先聚合为单向量适合方案A)。
---
## 用户使用场景
调研结果显示用户主要有以下两种使用意图
| 场景 | 描述 | 对应训练子集 |
| ---------------- | ------------------------------------ | ------------------------- |
| **完全随机生成** | 用户不提供任何条件直接生成完整地图 | 子集 C随机墙壁 + Mask |
| **墙壁辅助生成** | 用户手绘墙壁结构模型填充非墙内容 | 子集 B / D |
训练策略需针对这两种场景显式建模通过多子集混合训练使单一模型同时支持两种工作模式
---
## 训练策略
### 多子集混合训练
每个 batch 从四个子集中按比例采样各子集对模型能力的针对性不同
#### 子集 A标准 MaskGIT 训练(比例 0.4 ~ 0.6
保留原有 MaskGIT 掩码范式随机遮盖部分 tile训练模型基本的补全能力
```
原始地图 ──► 随机 Mask 部分 tile ──► MaskGIT 预测被遮盖 tile
```
- 掩码策略沿用现有 `MapMask`随机掩码 + 分块掩码 + 形态学变换
- 作用维持模型的基础生成质量防止针对性训练导致退化
#### 子集 B墙壁条件生成比例 0.1 ~ 0.3
去除地图中所有非墙 tile保留墙壁和空地让模型在给定墙壁结构下生成其余内容
```
原始地图 ──► 清除所有非墙 tile保留 空地/墙壁)──► MaskGIT 生成非墙内容
```
- 输入仅含墙壁tile=1和空地tile=0的地图
- 目标模型输出完整的地图含怪物道具钥匙等
- 对应场景**用户手绘墙壁 模型填充**
#### 子集 C随机条件生成比例 0.1 ~ 0.3
在子集 B 的基础上进一步对墙壁进行随机 Mask仅保留部分墙壁作为初始条件引入随机性以支持完全随机生成场景
```
原始地图 ──► 清除非墙 tile ──► 随机 Mask 部分墙壁 tile ──► MaskGIT 生成
```
- 输入随机保留的稀疏墙壁片段
- 目标生成完整地图含墙壁补全 + 非墙内容填充
- 对应场景**完全随机生成**墙壁本身也需要生成
- z 的随机采样在此子集中发挥关键作用控制生成风格多样性
#### 子集 D入口条件生成比例 0.1 ~ 0.2
在子集 B / C 的基础上额外保留入口 tiletile=10训练模型在给定入口位置时调整生成结果。
```
原始地图 ──► 清除非墙非入口 tile保留 空地/墙壁/入口)──► MaskGIT 生成
```
- 输入墙壁结构 + 入口位置
- 目标生成与入口位置语义一致的完整地图
- 对应场景**用户指定入口 模型生成与之匹配的布局**
- 单独设立此子集的原因是入口对地图整体拓扑结构有较强约束需要专项强化
### 各子集比例汇总
| 子集 | 训练任务 | 建议比例 |
| ---- | -------------------------------- | --------- |
| A | 标准 MaskGIT随机掩码补全 | 0.4 ~ 0.6 |
| B | 墙壁条件生成保留全部墙壁 | 0.1 ~ 0.3 |
| C | 随机条件生成随机保留部分墙壁 | 0.1 ~ 0.3 |
| D | 入口条件生成保留墙壁 + 入口 | 0.1 ~ 0.2 |
各子集比例之和为 1可作为 `dataset.py` 中采样权重使用
### 联合训练流程
VQ-VAE 编码器与改造后的 MaskGIT 一起训练
```
真实地图 ──► VQ-VAE 编码器 ──► z离散
按子集构造的掩码地图 + z ──► MaskGIT ──► 预测 logits ──► 交叉熵损失
VQ 损失commit + uniform
```
总损失
$$\mathcal{L} = \mathcal{L}_{CE}(\text{MaskGIT}) + \beta \cdot \mathcal{L}_{commit} + \gamma \cdot \mathcal{L}_{uniform}$$
### Dropout z提升鲁棒性
训练时以一定概率 10-20% z 替换为随机采样的码字模拟推理时随机采样 z 的情况避免模型过度依赖 z 的精确语义在子集 C完全随机生成中此机制尤为重要
---
## 推理流程
### 场景一:完全随机生成
```
1. 随机从 [0, K) 均匀采样 L 个 index
2. 从 codebook 查表,得到 z [L, d_z]
3. 初始化地图序列:以少量随机墙壁(比例 0.02 ~ 0.1)作为初始条件,其余位置填充 MASK token
4. 迭代 MaskGIT 解码cosine schedule
a. 输入掩码地图 + z → 预测所有 MASK 位置的 logits
b. 取置信度最高的 N 个 token 进行 unmask
c. 重复直到无 MASK token
5. 输出最终地图
```
**初始墙壁比例说明** MASK 初始化会导致生成多样性不佳即使使用 top-k 采样也难以改善)。引入少量随机墙壁2%~10%作为种子可有效提升多样性但比例不宜过高否则随机放置的墙壁可能产生拓扑结构冲突影响生成质量
### 场景二:墙壁辅助生成(用户手绘墙壁)
```
1. 用户提供墙壁布局tile=1 的位置),其余位置填充 MASK token
2. 随机从 [0, K) 均匀采样 L 个 index得到 z
3. 迭代 MaskGIT 解码:
a. 输入带已知墙壁的掩码地图 + z → 预测 MASK 位置的 logits
b. 已知墙壁位置不参与 unmask保持不变
c. 取置信度最高的 N 个 MASK token 进行 unmask
d. 重复直到无 MASK token
4. 输出最终地图
```
### 场景三:入口指定生成
```
与场景二类似但初始条件同时固定墙壁位置和入口位置tile=10
其余 MASK 位置由 MaskGIT 迭代解码填充
```
---
## 超参数(待确定)
| 参数 | 说明 | 建议初始值 |
| -------------- | --------------------- | ---------- |
| `K` | codebook 大小 | 8 ~ 32 |
| `L` | 码字序列长度方案B | 1 ~ 4 |
| `d_z` | 码字嵌入维度 | 64 ~ 128 |
| `β` | commit loss 权重 | 0.25 |
| `γ` | uniform loss 权重 | 0.1 |
| z dropout 概率 | 训练时随机 z 的比例 | 0.1 ~ 0.2 |
地图固定为 13×13z 定位为风格与多样性控制信号K 无需过大8~32 足够L 也应保持较小以避免过度约束生成
---
## 已决定事项
| 决策点 | 结论 |
| -------------- | -------------------------------------------------------------------------- |
| VQ-VAE 解码器 | **不加**直接与 MaskGIT 联合训练z 语义由生成损失反向传播决定 |
| 编码器架构方案 | **方案 B**序列码字 Transformer 结构若训练不稳定可退回方案 A |
| z 注入方式 | **Cross-Attention**方式一 |
| 均匀正则化策略 | **熵最大化**若出现 collapse 再引入 EMA |
| 分层 VQ | **单层**13×13 地图无需多级量化 |
| 标量 cond 保留 | **暂不使用**训练集标量分布严格推理阶段难以生成合理值后续可再考虑加入 |
## 待探索事项
- 合适的 KL 取值建议从 K=16, L=2 开始实验K=1, L=64 可能比较合适
- z dropout 的最优概率
- 若后续 codebook collapse引入 EMA 更新 + 重置机制
- 若后续需要更细粒度控制加入标量 cond需对推理侧标量做随机扰动处理
---
## 下一步行动
- [x] 确定编码方案方案 B序列码字
- [x] 确定 z 注入方式Cross-Attention
- [x] 确定是否需要 VQ-VAE 解码器不加
- [x] 确定是否保留标量 cond暂不使用
- [x] 确定训练子集划分与各场景比例
- [x] 实现 VQ-VAE 编码器模块Transformer + VQ
- [x] 改造 `GinkaMaskGIT.forward()` 接受 z替换热力图分支
- [x] 实现四种子集采样逻辑`dataset.py` 新增多子集 Dataset 或采样权重
- [x] 实现子集 B/C/D 的输入构造函数按规则清除 tile保留墙壁/入口
- [x] 编写联合训练脚本整合 VQ 损失与 MaskGIT 交叉熵损失