ginka-generator/docs/vqvae-split-channel-design.md
unanmed 3874d4dd95 feat: VQ 编码器分为三部分
Co-authored-by: Copilot <copilot@github.com>
2026-04-28 18:41:57 +08:00

323 lines
17 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# VQ 编码器三通道分拆预训练设计文档
## 背景与问题诊断
### 核心问题
在当前 VQ-VAE + MaskGIT 联合训练方案(以及方案 A 闭环约束、方案 D 全图预训练)中,一个根本性的类别不均衡问题始终未被解决:
地图中约 7085% 的格子为墙壁tile=1和空地tile=0而怪物9、门2、入口10、宝石4/5/6、钥匙3、药水7、道具8等功能性 tile 仅占少数。当编码器以完整地图为输入、重建损失对所有位置平等计算时:
- 编码器的梯度信号被墙壁/空地主导,特征向量主要编码了空间结构,而非功能性内容;
- 解码(或生成)时,怪物/门/资源类 tile 的召回率远低于墙壁;
- 即使方案 D 进行了全图预训练,由于损失仍是全位置等权 Focal Loss预训练阶段本身就强化了这种偏向
- **已将损失由 CE 替换为 Focal Loss但召回率仍无明显改善**——说明问题的根源在于梯度信号的来源比例,而非损失函数本身的形式;
- 方案 A 的闭环一致性约束在 z 本身质量不足时,仅能强化已有的偏向,难以从根本上改善。
### 改进目标
通过**按语义层次将地图拆分为三个独立通道**,采用**累积式输入 + 通道专属损失**的设计:每个通道的编码器输入包含当前通道及所有低等级 tile提供空间上下文而解码头的损失**仅在本通道专属 tile 的位置计算,低等级 tile 的损失权重为 0**。这样既保证了编码器能在有意义的空间结构中学习(避免功能性 tile 沦为孤立散点),又迫使解码器必须正确预测本通道 tile 才能降低损失——无法靠拟合高频的墙壁/空地来刷低损失值。
---
## Tile 类型划分
根据 `data/src/shared.ts` 中的定义,将 15 个有效 tile 类型(不含 MASK=15按游戏语义分为三个通道
| 通道 | 编码器输入(切片内容) | 解码损失计算范围 | 语义含义 |
| ------ | ---------------------------------------------------- | ------------------ | ------------------------------------ |
| 通道 1 | floor(0) + wall(1) | {1}(仅墙壁) | 空间骨架(地形结构) |
| 通道 2 | floor(0) + wall(1) + door(2) + mob(9) + entrance(10) | {2, 9, 10} | 关卡门控(交互元素,决定路径可达性) |
| 通道 3 | 完整地图(所有 tile | {3, 4, 5, 6, 7, 8} | 收集资源(奖励与道具) |
**通道划分的关键设计原则**
- 通道 1 仅包含 floor 和 wall编码器集中学习地图的空间骨架结构
- 通道 2 的输入**保留墙壁与空地作为空间上下文**,在此基础上叠加 door/mob/entrance确保编码器能感知功能性 tile 在地图中的位置关系,而非将其视为无空间依附的孤立散点;
- 通道 3 的输入为**完整地图**,编码器在包含骨架与关卡结构的完整上下文中学习资源 tile 的空间分布;
- 预训练时每个解码头的损失**仅在本通道专属 tile 的位置计算,低等级 tilefloor、wall 及前级通道的 tile损失权重为 0**——迫使编码器必须通过正确预测本通道 tile 来降低损失,无法靠拟合高频背景来规避优化压力。
---
## 整体架构
### 预训练阶段
```
完整地图 [B, H*W]
├──► 切片 1floor(0)+wall(1)(这已是全部内容,无需替换)
│ │
│ ▼
│ Encoder_1 → VQ_1 → z_1 [B, L_1, d_z]
│ │
│ ▼
│ DecodeHead_1 → logits [B, H*W, C]
│ │
│ ▼
│ Loss_1仅在 tile∈{1} 的位置计算 Focal Loss仅墙壁空地权重为 0
├──► 切片 2保留 floor(0)+wall(1)+door(2)+mob(9)+entry(10)其余→floor(0)
│ │
│ ▼
│ Encoder_2 → VQ_2 → z_2 [B, L_2, d_z]
│ │
│ ▼
│ DecodeHead_2 → logits [B, H*W, C]
│ │
│ ▼
│ Loss_2仅在 tile∈{2,9,10} 的位置计算 Focal Lossfloor/wall 权重为 0
└──► 切片 3完整地图所有 tile无需替换
Encoder_3 → VQ_3 → z_3 [B, L_3, d_z]
DecodeHead_3 → logits [B, H*W, C]
Loss_3仅在 tile∈{3,4,5,6,7,8} 的位置计算 Focal Loss其余 tile 权重为 0
```
三路编码器**相互独立预训练**,每路的预训练损失:
$$\mathcal{L}_{pretrain}^{(k)} = \mathcal{L}_{FL}^{(k)} + \beta \cdot \mathcal{L}_{commit}^{(k)} + \gamma \cdot \mathcal{L}_{uniform}^{(k)}$$
其中 $\mathcal{L}_{FL}^{(k)}$ 为通道 $k$ 的通道专属掩码 Focal Loss见下节
### 联合训练阶段
```
完整地图 ──► 三路切片 ──► [Enc_1, Enc_2, Enc_3] ──► [z_1, z_2, z_3]
z = Concat([z_1, z_2, z_3], dim=1)
掩码地图 + z ──► MaskGIT (Cross-Attention) ──► 预测 logits ──► Focal Loss
```
联合训练总损失(不含预训练解码头):
$$\mathcal{L}_{joint} = \mathcal{L}_{FL}^{MaskGIT} + \sum_{k=1}^{3} \left( \beta \cdot \mathcal{L}_{commit}^{(k)} + \gamma \cdot \mathcal{L}_{uniform}^{(k)} \right)$$
---
## 通道专属掩码 Focal Loss
这是方案的核心机制。对于通道 $k$,设其专属 tile 集合为 $\mathcal{T}_k$(不含低等级 tile则损失计算为
$$\mathcal{L}_{FL}^{(k)} = \frac{\sum_{i=1}^{H \times W} \mathbf{1}[y_i \in \mathcal{T}_k] \cdot \text{FL}(\hat{y}_i, y_i)}{\sum_{i=1}^{H \times W} \mathbf{1}[y_i \in \mathcal{T}_k] + \epsilon}$$
其中 $y_i$ 为真实 tile 类型,$\hat{y}_i$ 为解码头输出的 logits$\text{FL}$ 为 Focal Loss。
**实现方式**PyTorch 伪代码):
```python
# 通道 2 示例(输入切片已包含 floor+wall 作为上下文)
CHANNEL_2_TILES = {2, 9, 10} # door, mob, entrance
# target: 完整地图 ground truth[B, H*W]
# logits: DecodeHead_2 输出,[B, H*W, num_classes]
mask = torch.zeros_like(target, dtype=torch.bool)
for t in CHANNEL_2_TILES:
mask |= (target == t) # [B, H*W] bool仅通道 2 专属 tile 的位置为 True
# focal_loss: reduction='none',返回 [B * H*W]
fl = focal_loss(
logits.view(-1, num_classes),
target.view(-1),
) # [B * H*W]
fl = fl.view(B, -1) # [B, H*W]
loss_ch2 = (fl * mask).sum() / (mask.sum() + 1e-6)
```
**为什么输入包含墙壁、但损失不计算墙壁**:通道 2 的切片中保留了 floor+wall是为了给编码器提供空间结构上下文使门/怪/入口的位置有意义(否则孤立散点难以形成有效表示)。但损失仅在 `{2, 9, 10}` 位置计算,确保梯度信号完全来自这三类 tile——编码器如果只靠拟合高频的墙壁/空地,解码头在 `{2, 9, 10}` 位置的损失无法降低,从而被迫学习功能性 tile 的空间分布。
---
## 切片构造规则
| 通道 | 切片中保留的 tile | 其余位置替换为 | 解码损失计算的位置(专属 tile 集合 $\mathcal{T}_k$ |
| ---- | ----------------- | -------------- | ------------------------------------------------------- |
| 1 | 0, 1 | —(无需替换) | {1}(仅墙壁;空地是墙壁的补集,能预测墙壁即能区分空地) |
| 2 | 0, 1, 2, 9, 10 | 0floor | {2, 9, 10}floor/wall 损失权重为 0 |
| 3 | 全部(完整地图) | —(无需替换) | {3, 4, 5, 6, 7, 8}(其余 tile 损失权重为 0 |
---
## 编码器架构设计
### 三路独立编码器
三个编码器均复用现有的 `GinkaVQVAE` 类,配置略有差异:
| 参数 | Encoder_1结构骨架 | Encoder_2关卡门控 | Encoder_3收集资源 |
| --------------- | --------------------- | --------------------- | --------------------- |
| `L`(码字数) | 2 | 2 | 2 |
| `K`(码本大小) | 16 | 16 | 16 |
| `d_z` | 64 | 64 | 64 |
| `d_model` | 128 | 64 | 64 |
| `num_layers` | 2 | 2 | 2 |
- Encoder_1 处理高频 tile适当加大 `d_model`
- Encoder_2 和 Encoder_3 的功能性 tile 稀疏,信息量较小,可使用较小的 `d_model`
- 三路 `d_z` 保持一致,以便拼接后维度齐整;
- 总参数量估算Encoder_1 ~400K + Encoder_2 ~150K + Encoder_3 ~150K ≈ **700K**,在 1M 预算内。
> 考虑到训练集数量较少,可以考虑适当降低 K 和 L 的值避免模型死记硬背也可以防止训练集没有覆盖全部所有情况。1M 的参数量仅做估计,可以先尝试较大的参数量,如出现过拟合再降低。
### 联合训练时的 z 拼接
$$z = \text{Concat}([z_1, z_2, z_3], \dim=1) \in \mathbb{R}^{B \times (L_1+L_2+L_3) \times d_z}$$
以各通道 `L=2` 为例,总 memory 长度为 6与当前 MaskGIT Cross-Attention 的 memory 规模(原来 L=2相比略有增加但绝对长度仍很小不影响计算效率。
---
## 预训练流程
### 解码头复用
预训练时三路各自使用一个 `VQDecodeHead` 实例(现有类,`num_classes=16`),预训练结束后整体丢弃。解码头参数不迁移到联合训练阶段。
### 训练脚本
新增 `ginka/train_pretrain_split.py`(独立于现有的 `train_pretrain.py`
```python
# 伪代码结构
for epoch in ...:
for batch in dataloader:
raw_map = batch["raw_map"] # [B, H*W] 完整地图
# ─── 通道 1 ───
slice1 = make_slice(raw_map, keep={0, 1}) # floor+wall切片即完整输入
z_q1, z_e1, _, vq_loss1, *_ = enc1(slice1)
logits1 = head1(z_q1)
fl1 = masked_focal(logits1, raw_map, tile_set={1}) # 仅对 wall 计损失
loss1 = fl1 + vq_loss1
# ─── 通道 2 ───
slice2 = make_slice(raw_map, keep={0, 1, 2, 9, 10}) # 保留 wall/floor 作为上下文
z_q2, z_e2, _, vq_loss2, *_ = enc2(slice2)
logits2 = head2(z_q2)
fl2 = masked_focal(logits2, raw_map, tile_set={2, 9, 10}) # 仅对专属 tile 计损失
loss2 = fl2 + vq_loss2
# ─── 通道 3 ───
slice3 = raw_map # 完整地图,无需切片
z_q3, z_e3, _, vq_loss3, *_ = enc3(slice3)
logits3 = head3(z_q3)
fl3 = masked_focal(logits3, raw_map, tile_set={3, 4, 5, 6, 7, 8}) # 仅对专属 tile 计损失
loss3 = fl3 + vq_loss3
total = loss1 + loss2 + loss3
total.backward()
optimizer.step()
```
三路编码器可以**同步训练**(同一 optimizer也可以分别独立训练——独立训练更灵活可以对收敛速度差异大的通道单独调参。
### 预训练监控指标
| 指标 | 说明 |
| ----------------------- | ---------------------------------------------------------------- |
| 通道 1 wall 位置准确率 | Encoder_1 能否正确重建墙壁分布 |
| 通道 2 功能 tile 召回率 | Encoder_2 对 door/mob/entrance 各类的召回,应 > 50%(稀疏 tile |
| 通道 3 资源 tile 召回率 | Encoder_3 对各资源类的召回 |
| codebook 使用熵(各路) | 各通道 codebook 是否均匀使用,避免 collapse |
由于通道 2/3 的 tile 在每张地图中数量极少(典型地图中 door/mob/entrance 合计约 1020 格,资源合计约 1015 格),召回率指标比准确率更有意义。
---
## 联合训练流程
### 三阶段训练
| 阶段 | 模型状态 | 目标 | 建议轮数 |
| -------------------- | ----------------------------------------- | ------------------------------------------------ | ------------ |
| 阶段 0分通道预训练 | 三路 Encoder + 三路 DecodeHead | 各通道 Focal Loss 收敛,功能 tile 召回率达到目标 | 3060 epoch |
| 阶段 1冻结热身 | 三路 Encoder 冻结 + MaskGIT 全参训练 | MaskGIT 适应三路 z 的联合分布 | 2040 epoch |
| 阶段 2完整联合训练 | 全部参数解冻Encoder 使用较小 LR×0.1 | 端到端联合优化 | 正常训练轮数 |
### 联合训练数据集
`GinkaJointDataset` 需扩展为同时提供三路切片:
```python
# 返回字典新增字段
{
"raw_map": ..., # [H*W] 完整地图VQ 编码器输入)
"slice1": ..., # [H*W] 通道 1 切片
"slice2": ..., # [H*W] 通道 2 切片
"slice3": ..., # [H*W] 通道 3 切片
"masked_map": ..., # [H*W] MaskGIT 输入(掩码后地图)
"target_map": ..., # [H*W] MaskGIT CE ground truth
}
```
---
## 推理时的 z 采样
推理时三路编码器均独立采样,无需用户输入:
```python
# 完全随机生成
z1 = enc1.sample(B, device) # [B, L1, d_z]
z2 = enc2.sample(B, device) # [B, L2, d_z]
z3 = enc3.sample(B, device) # [B, L3, d_z]
z = torch.cat([z1, z2, z3], dim=1) # [B, L1+L2+L3, d_z]
```
**分通道条件控制**(可选扩展):
| 场景 | 通道 1 z 来源 | 通道 2 z 来源 | 通道 3 z 来源 |
| ------------ | ------------- | ------------- | ------------- |
| 完全随机生成 | 随机采样 | 随机采样 | 随机采样 |
| 指定墙壁布局 | 用户地图编码 | 随机采样 | 随机采样 |
| 指定关卡结构 | 随机采样 | 参考图编码 | 随机采样 |
| 风格迁移 | 参考图编码 | 参考图编码 | 随机采样 |
---
## 与现有方案的关系
| 方案 | 与本方案的关系 |
| ------ | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| 方案 A | 兼容。由于通道 2/3 的编码器输入包含墙壁上下文(而非孤立散点),其 z 具备稳定的空间语义,一致性约束可同样作用于三个通道:通道 1 约束骨架结构,通道 2 约束关卡元素分布,通道 3 约束资源分布 |
| 方案 D | 本方案**替代并强化**方案 D 的预训练思路:方案 D 是全图等权 Focal Loss 预训练,本方案通过累积式输入 + 通道专属损失从根本上解决了类别不均衡问题 |
| 方案 C | 兼容。多阶段生成(先墙后门后资源)可以将通道 1/2/3 的 z 分别作为各阶段的生成条件 |
---
## 超参数建议
| 参数 | 建议初始值 | 备注 |
| -------------------------- | ---------- | ---------------------------------------- |
| 各通道 `L`(码字数) | 2 | 三路合计 6可视效果适当扩大 |
| 各通道 `K`(码本大小) | 16 | 通道 2/3 可减小到 8tile 种类少) |
| `d_z` | 64 | 三路保持一致,便于拼接 |
| `β`commit loss 权重) | 0.25 | 同现有配置 |
| `γ`uniform loss 权重) | 0.1 | 通道 2/3 码本小,可适当增大到 0.2 |
| 预训练 epoch | 3060 | 以功能 tile 召回率达标为准,不以轮数为限 |
| 联合训练 Encoder LR 缩放比 | 0.1 | 阶段 2 解冻后使用较小 LR 微调 |
| z dropout 概率(联合训练) | 0.10.2 | 三路 z 各自独立 dropout |
---
## 实施步骤
- [ ]`ginka/dataset.py` 中实现 `make_slice(map, keep_set)` 辅助函数,生成三路切片
- [ ] 扩展 `GinkaJointDataset.__getitem__`,新增 `slice1/slice2/slice3` 字段
- [ ]`ginka/vqvae/model.py` 中确认 `GinkaVQVAE` 可独立实例化三次(无全局状态)
- [ ] 实现 `masked_focal(logits, target, tile_set)` 工具函数(`ginka/utils.py`
- [ ] 新增 `ginka/train_pretrain_split.py` 预训练脚本(支持三路同步或分路训练)
- [ ] 修改 `ginka/train_vq.py`(联合训练脚本),支持加载三路编码器权重并拼接 z
- [ ] 修改 `GinkaMaskGIT.forward()` 以接受 `[B, L1+L2+L3, d_z]` 的拼接 zCross-Attention memory
- [ ] 添加联合训练监控:各通道 codebook 使用熵、功能 tile 召回率