mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
398 lines
17 KiB
Markdown
398 lines
17 KiB
Markdown
# Ginka 地图生成模型重构设计文档
|
||
|
||
## 背景与动机
|
||
|
||
### 现状与问题
|
||
|
||
当前模型采用**热力图(Heatmap)**作为生成条件,引导 MaskGIT 模型生成地图。实践中发现以下问题:
|
||
|
||
1. **模型侧**:MaskGIT 无法有效利用热力图信息,热力图对生成结果的控制力弱;
|
||
2. **用户侧**:用户难以手动构造语义合理的热力图,交互成本高;
|
||
3. **表达侧**:热力图是高维连续信号,携带冗余信息,但实际指导生成的有效信息量有限。
|
||
|
||
### 改进目标
|
||
|
||
- 保留 MaskGIT 的地图生成范式,但将生成条件改为**低维离散隐变量 z**;
|
||
- z 由 **VQ-VAE 风格的编码器**从真实地图中抽取,形成一个可复用的 **codebook**;
|
||
- 训练期间通过 codebook 为 MaskGIT 提供多样性控制信号;
|
||
- 推理期间**直接随机采样 z**,无需用户提供任何额外输入。
|
||
|
||
---
|
||
|
||
## 整体架构
|
||
|
||
系统由两个子模型组成,训练可以分阶段进行,也可以联合训练。
|
||
|
||
```
|
||
┌─────────────────────────────────────────────────────────┐
|
||
│ 训练阶段 │
|
||
│ │
|
||
│ 真实地图 ──► [VQ-VAE 编码器] ──► z(离散码序列) │
|
||
│ │ │
|
||
│ ▼ │
|
||
│ [Codebook] │
|
||
│ │ │
|
||
│ ▼ │
|
||
│ 掩码地图 + z ──► [MaskGIT] ──► 预测被遮盖的图块 │
|
||
└─────────────────────────────────────────────────────────┘
|
||
|
||
┌─────────────────────────────────────────────────────────┐
|
||
│ 推理阶段 │
|
||
│ │
|
||
│ 随机采样 z ──► [MaskGIT 迭代解码] ──► 生成地图 │
|
||
└─────────────────────────────────────────────────────────┘
|
||
```
|
||
|
||
---
|
||
|
||
## 子模型一:VQ-VAE 风格编码器
|
||
|
||
**规模约束**:编码器参数量应控制在 **< 1M**。此前 MaskGIT 在热力图方案中已验证 4M 模型处于欠拟合与过拟合的边界,改用新方案整体规模不会有大幅变化,VQ-VAE 编码器作为辅助模块无需过大。
|
||
|
||
### 职责
|
||
|
||
将一张完整的地图(`[H, W]` 整数矩阵)编码为一个**离散的 z**,z 是从 codebook 中查得的嵌入向量(或向量序列)。
|
||
|
||
### 设计方案
|
||
|
||
编码器统一采用**纯 Transformer 结构**:地图尺寸(13×13)不大,Transformer 不会带来过大计算开销,同时更擅长捕捉图块之间的长程依赖,优于 CNN 方案。
|
||
|
||
#### 方案 A:全图单一码字(z 为标量 index)
|
||
|
||
- 编码器将整张地图压缩为一个 `d_z` 维向量,再量化为 codebook 中的单个码字;
|
||
- z 是一个整数 index,对应 codebook 中一行嵌入;
|
||
- 结构最简单,但表达能力有限,难以捕捉地图的局部多样性。
|
||
|
||
```
|
||
地图 [H, W] ──► Transformer 编码 ──► [d_z] ──► 量化 ──► index(标量)
|
||
│
|
||
▼
|
||
codebook[index] → z [d_z]
|
||
```
|
||
|
||
#### 方案 B:空间码字序列(z 为序列,**已选定**)
|
||
|
||
- 编码器将地图编码为 `[L, d_z]` 的特征序列(L 为码字数量);
|
||
- 每个位置独立量化为 codebook 中的一个码字;
|
||
- z 是 `L` 个 index 的组合,表达能力更强;
|
||
- 推理时随机采样 L 个 index 即可。
|
||
|
||
```
|
||
地图 [H, W] ──► Transformer 编码 ──► [L, d_z] ──► 逐位量化 ──► [L 个 index]
|
||
│
|
||
▼
|
||
codebook[index_0..L] → z [L, d_z]
|
||
```
|
||
|
||
### VQ-VAE 量化机制
|
||
|
||
标准向量量化(Vector Quantization):
|
||
|
||
$$z_q = \arg\min_{e_k \in \mathcal{E}} \| z_e - e_k \|_2$$
|
||
|
||
其中 $\mathcal{E} = \{e_1, ..., e_K\}$ 为 codebook,$K$ 为码本大小。
|
||
|
||
**直通估计(Straight-Through Estimator)**用于反向传播:
|
||
|
||
$$\frac{\partial \mathcal{L}}{\partial z_e} \approx \frac{\partial \mathcal{L}}{\partial z_q}$$
|
||
|
||
### 损失函数设计
|
||
|
||
总损失由三部分组成:
|
||
|
||
$$\mathcal{L}_{VQVAE} = \mathcal{L}_{recon} + \beta \cdot \mathcal{L}_{commit} + \gamma \cdot \mathcal{L}_{uniform}$$
|
||
|
||
| 损失项 | 说明 |
|
||
| ----------------------- | ----------------------------------------------------------- |
|
||
| $\mathcal{L}_{recon}$ | 重建损失,确保 z 能恢复地图信息(可选,视是否加解码器而定) |
|
||
| $\mathcal{L}_{commit}$ | 承诺损失 $\|z_e - \text{sg}(z_q)\|^2$,拉近编码向量与码字 |
|
||
| $\mathcal{L}_{uniform}$ | **均匀分布正则化**,鼓励各码字被均等使用(见下节) |
|
||
|
||
#### 均匀分布正则化
|
||
|
||
目标:让所有 `K` 个码字在训练中被均等使用,避免 codebook collapse。
|
||
|
||
方案一:**熵最大化**
|
||
|
||
$$\mathcal{L}_{uniform} = -H(p) = \sum_{k=1}^{K} p_k \log p_k$$
|
||
|
||
其中 $p_k$ 为码字 $k$ 在当前 batch 中被选中的频率(使用 EMA 估计)。
|
||
|
||
方案二:**EMA 更新 + 重置**
|
||
|
||
- 使用指数移动平均(EMA)更新 codebook,不通过梯度更新;
|
||
- 定期检测使用率低于阈值的码字,将其重置为当前 batch 中随机样本点。
|
||
|
||
方案三:**codebook 使用 KL 散度**
|
||
|
||
$$\mathcal{L}_{uniform} = \text{KL}(p \| \mathcal{U}_{K})$$
|
||
|
||
其中 $\mathcal{U}_K$ 为均匀分布,即期望每个码字被选中概率为 $1/K$。
|
||
|
||
**初始方案**:采用**熵最大化**(方案一),通过梯度直接优化,实现简单;若后续出现 codebook collapse 问题,再引入 EMA 更新 + 重置机制。
|
||
|
||
### 是否需要解码器
|
||
|
||
| | 有解码器 | 无解码器 |
|
||
| ---------- | -------------------------- | ----------------------------------------- |
|
||
| 重建损失 | 有,z 需要还原地图 | 无,z 的质量由 MaskGIT 的生成质量间接约束 |
|
||
| 训练复杂度 | 较高 | 较低 |
|
||
| z 语义性 | 强,z 明确包含地图结构信息 | 弱,z 更像风格/多样性控制 |
|
||
| 推荐场景 | 需要 z 携带结构语义 | 仅用于多样性控制 |
|
||
|
||
**已确认:不加解码器**,直接与 MaskGIT 联合训练。z 的语义由 MaskGIT 端的生成损失反向传播决定。z 定位为风格与多样性控制信号,而非结构重建指导——若 z 能完整重建地图,模型会忽略其他条件直接依赖 z,反而降低可控性与泛化能力。
|
||
|
||
相应地,损失函数简化为:
|
||
|
||
$$\mathcal{L} = \mathcal{L}_{CE}(\text{MaskGIT}) + \beta \cdot \mathcal{L}_{commit} + \gamma \cdot \mathcal{L}_{uniform}$$
|
||
|
||
---
|
||
|
||
## 子模型二:MaskGIT 改造
|
||
|
||
### 现有输入
|
||
|
||
```python
|
||
def forward(self, map: torch.Tensor, heatmap: torch.Tensor):
|
||
# map: [B, H * W]
|
||
# heatmap: [B, C, H, W]
|
||
```
|
||
|
||
### 改造后输入
|
||
|
||
```python
|
||
def forward(self, map: torch.Tensor, z: torch.Tensor):
|
||
# map: [B, H * W] — 掩码后的地图 token 序列
|
||
# z: [B, L, d_z] — 离散隐变量(方案B)或 [B, d_z](方案A)
|
||
```
|
||
|
||
移除原有的 `GinkaMaskGITCond`(热力图 CNN 编码器)和 `gate_encoder`,用 z 替代其功能。
|
||
|
||
### z 注入 MaskGIT 的方式(待选)
|
||
|
||
#### 方式一:Cross-Attention(推荐)
|
||
|
||
Transformer 解码器天然支持 cross-attention:
|
||
|
||
```
|
||
map token sequence ──► self-attention ──► cross-attention ◄── z
|
||
│
|
||
▼
|
||
预测 logits
|
||
```
|
||
|
||
z 作为 memory 输入到 `TransformerDecoder`,map token 序列作为 query。
|
||
|
||
优点:z 可以是任意长度的序列(方案B友好),表达力强。**已选定此方案。**
|
||
|
||
#### 方式二:z 拼接到序列头部(Prefix Token)
|
||
|
||
将 z 的每个码字嵌入作为前缀 token,拼接到 map token 序列前:
|
||
|
||
```
|
||
[z_0, z_1, ..., z_L, tile_0, tile_1, ..., tile_{H*W}]
|
||
```
|
||
|
||
输入到 TransformerEncoder 做自注意力,位置编码需相应扩展。
|
||
|
||
优点:实现简单,无需 encoder-decoder 分离。
|
||
|
||
#### 方式三:Add/FiLM 调制
|
||
|
||
将 z 投影后,用 FiLM(Feature-wise Linear Modulation)调制每一层的特征:
|
||
|
||
$$y = \gamma(z) \odot x + \beta(z)$$
|
||
|
||
优点:参数高效;缺点:z 必须先聚合为单向量(适合方案A)。
|
||
|
||
---
|
||
|
||
## 用户使用场景
|
||
|
||
调研结果显示,用户主要有以下两种使用意图:
|
||
|
||
| 场景 | 描述 | 对应训练子集 |
|
||
| ---------------- | ------------------------------------ | ------------------------- |
|
||
| **完全随机生成** | 用户不提供任何条件,直接生成完整地图 | 子集 C(随机墙壁 + Mask) |
|
||
| **墙壁辅助生成** | 用户手绘墙壁结构,模型填充非墙内容 | 子集 B / D |
|
||
|
||
训练策略需针对这两种场景显式建模,通过多子集混合训练使单一模型同时支持两种工作模式。
|
||
|
||
---
|
||
|
||
## 训练策略
|
||
|
||
### 多子集混合训练
|
||
|
||
每个 batch 从四个子集中按比例采样,各子集对模型能力的针对性不同:
|
||
|
||
#### 子集 A:标准 MaskGIT 训练(比例 0.4 ~ 0.6)
|
||
|
||
保留原有 MaskGIT 掩码范式,随机遮盖部分 tile,训练模型基本的补全能力。
|
||
|
||
```
|
||
原始地图 ──► 随机 Mask 部分 tile ──► MaskGIT 预测被遮盖 tile
|
||
```
|
||
|
||
- 掩码策略:沿用现有 `MapMask`(随机掩码 + 分块掩码 + 形态学变换)
|
||
- 作用:维持模型的基础生成质量,防止针对性训练导致退化
|
||
|
||
#### 子集 B:墙壁条件生成(比例 0.1 ~ 0.3)
|
||
|
||
去除地图中所有非墙 tile(保留墙壁和空地),让模型在给定墙壁结构下生成其余内容。
|
||
|
||
```
|
||
原始地图 ──► 清除所有非墙 tile(保留 空地/墙壁)──► MaskGIT 生成非墙内容
|
||
```
|
||
|
||
- 输入:仅含墙壁(tile=1)和空地(tile=0)的地图
|
||
- 目标:模型输出完整的地图(含怪物、道具、门、钥匙等)
|
||
- 对应场景:**用户手绘墙壁 → 模型填充**
|
||
|
||
#### 子集 C:随机条件生成(比例 0.1 ~ 0.3)
|
||
|
||
在子集 B 的基础上,进一步对墙壁进行随机 Mask,仅保留部分墙壁作为初始条件,引入随机性以支持完全随机生成场景。
|
||
|
||
```
|
||
原始地图 ──► 清除非墙 tile ──► 随机 Mask 部分墙壁 tile ──► MaskGIT 生成
|
||
```
|
||
|
||
- 输入:随机保留的稀疏墙壁片段
|
||
- 目标:生成完整地图(含墙壁补全 + 非墙内容填充)
|
||
- 对应场景:**完全随机生成**(墙壁本身也需要生成)
|
||
- 注:z 的随机采样在此子集中发挥关键作用,控制生成风格多样性
|
||
|
||
#### 子集 D:入口条件生成(比例 0.1 ~ 0.2)
|
||
|
||
在子集 B / C 的基础上额外保留入口 tile(tile=10),训练模型在给定入口位置时调整生成结果。
|
||
|
||
```
|
||
原始地图 ──► 清除非墙非入口 tile(保留 空地/墙壁/入口)──► MaskGIT 生成
|
||
```
|
||
|
||
- 输入:墙壁结构 + 入口位置
|
||
- 目标:生成与入口位置语义一致的完整地图
|
||
- 对应场景:**用户指定入口 → 模型生成与之匹配的布局**
|
||
- 注:单独设立此子集的原因是入口对地图整体拓扑结构有较强约束,需要专项强化
|
||
|
||
### 各子集比例汇总
|
||
|
||
| 子集 | 训练任务 | 建议比例 |
|
||
| ---- | -------------------------------- | --------- |
|
||
| A | 标准 MaskGIT(随机掩码补全) | 0.4 ~ 0.6 |
|
||
| B | 墙壁条件生成(保留全部墙壁) | 0.1 ~ 0.3 |
|
||
| C | 随机条件生成(随机保留部分墙壁) | 0.1 ~ 0.3 |
|
||
| D | 入口条件生成(保留墙壁 + 入口) | 0.1 ~ 0.2 |
|
||
|
||
各子集比例之和为 1,可作为 `dataset.py` 中采样权重使用。
|
||
|
||
### 联合训练流程
|
||
|
||
将 VQ-VAE 编码器与改造后的 MaskGIT 一起训练:
|
||
|
||
```
|
||
真实地图 ──► VQ-VAE 编码器 ──► z(离散)
|
||
│
|
||
按子集构造的掩码地图 + z ──► MaskGIT ──► 预测 logits ──► 交叉熵损失
|
||
│
|
||
▼
|
||
VQ 损失(commit + uniform)
|
||
```
|
||
|
||
总损失:
|
||
|
||
$$\mathcal{L} = \mathcal{L}_{CE}(\text{MaskGIT}) + \beta \cdot \mathcal{L}_{commit} + \gamma \cdot \mathcal{L}_{uniform}$$
|
||
|
||
### Dropout z(提升鲁棒性)
|
||
|
||
训练时以一定概率(如 10-20%)将 z 替换为随机采样的码字,模拟推理时随机采样 z 的情况,避免模型过度依赖 z 的精确语义。在子集 C(完全随机生成)中此机制尤为重要。
|
||
|
||
---
|
||
|
||
## 推理流程
|
||
|
||
### 场景一:完全随机生成
|
||
|
||
```
|
||
1. 随机从 [0, K) 均匀采样 L 个 index
|
||
2. 从 codebook 查表,得到 z [L, d_z]
|
||
3. 初始化地图序列:以少量随机墙壁(比例 0.02 ~ 0.1)作为初始条件,其余位置填充 MASK token
|
||
4. 迭代 MaskGIT 解码(cosine schedule):
|
||
a. 输入掩码地图 + z → 预测所有 MASK 位置的 logits
|
||
b. 取置信度最高的 N 个 token 进行 unmask
|
||
c. 重复直到无 MASK token
|
||
5. 输出最终地图
|
||
```
|
||
|
||
**初始墙壁比例说明**:全 MASK 初始化会导致生成多样性不佳(即使使用 top-k 采样也难以改善)。引入少量随机墙壁(2%~10%)作为种子,可有效提升多样性;但比例不宜过高,否则随机放置的墙壁可能产生拓扑结构冲突,影响生成质量。
|
||
|
||
### 场景二:墙壁辅助生成(用户手绘墙壁)
|
||
|
||
```
|
||
1. 用户提供墙壁布局(tile=1 的位置),其余位置填充 MASK token
|
||
2. 随机从 [0, K) 均匀采样 L 个 index,得到 z
|
||
3. 迭代 MaskGIT 解码:
|
||
a. 输入带已知墙壁的掩码地图 + z → 预测 MASK 位置的 logits
|
||
b. 已知墙壁位置不参与 unmask,保持不变
|
||
c. 取置信度最高的 N 个 MASK token 进行 unmask
|
||
d. 重复直到无 MASK token
|
||
4. 输出最终地图
|
||
```
|
||
|
||
### 场景三:入口指定生成
|
||
|
||
```
|
||
与场景二类似,但初始条件同时固定墙壁位置和入口位置(tile=10)
|
||
其余 MASK 位置由 MaskGIT 迭代解码填充
|
||
```
|
||
|
||
---
|
||
|
||
## 超参数(待确定)
|
||
|
||
| 参数 | 说明 | 建议初始值 |
|
||
| -------------- | --------------------- | ---------- |
|
||
| `K` | codebook 大小 | 8 ~ 32 |
|
||
| `L` | 码字序列长度(方案B) | 1 ~ 4 |
|
||
| `d_z` | 码字嵌入维度 | 64 ~ 128 |
|
||
| `β` | commit loss 权重 | 0.25 |
|
||
| `γ` | uniform loss 权重 | 0.1 |
|
||
| z dropout 概率 | 训练时随机 z 的比例 | 0.1 ~ 0.2 |
|
||
|
||
地图固定为 13×13,z 定位为风格与多样性控制信号,K 无需过大(8~32 足够),L 也应保持较小以避免过度约束生成。
|
||
|
||
---
|
||
|
||
## 已决定事项
|
||
|
||
| 决策点 | 结论 |
|
||
| -------------- | -------------------------------------------------------------------------- |
|
||
| VQ-VAE 解码器 | **不加**,直接与 MaskGIT 联合训练,z 语义由生成损失反向传播决定 |
|
||
| 编码器架构方案 | **方案 B**(序列码字),全 Transformer 结构;若训练不稳定可退回方案 A |
|
||
| z 注入方式 | **Cross-Attention**(方式一) |
|
||
| 均匀正则化策略 | **熵最大化**;若出现 collapse 再引入 EMA |
|
||
| 分层 VQ | **单层**,13×13 地图无需多级量化 |
|
||
| 标量 cond 保留 | **暂不使用**;训练集标量分布严格,推理阶段难以生成合理值,后续可再考虑加入 |
|
||
|
||
## 待探索事项
|
||
|
||
- 合适的 K、L 取值(建议从 K=16, L=2 开始实验),K=1, L=64 可能比较合适。
|
||
- z dropout 的最优概率
|
||
- 若后续 codebook collapse:引入 EMA 更新 + 重置机制
|
||
- 若后续需要更细粒度控制:加入标量 cond(需对推理侧标量做随机扰动处理)
|
||
|
||
---
|
||
|
||
## 下一步行动
|
||
|
||
- [x] 确定编码方案(方案 B,序列码字)
|
||
- [x] 确定 z 注入方式(Cross-Attention)
|
||
- [x] 确定是否需要 VQ-VAE 解码器(不加)
|
||
- [x] 确定是否保留标量 cond(暂不使用)
|
||
- [x] 确定训练子集划分与各场景比例
|
||
- [x] 实现 VQ-VAE 编码器模块(Transformer + VQ)
|
||
- [x] 改造 `GinkaMaskGIT.forward()` 接受 z,替换热力图分支
|
||
- [x] 实现四种子集采样逻辑(`dataset.py` 新增多子集 Dataset 或采样权重)
|
||
- [x] 实现子集 B/C/D 的输入构造函数(按规则清除 tile、保留墙壁/入口)
|
||
- [x] 编写联合训练脚本,整合 VQ 损失与 MaskGIT 交叉熵损失
|