mirror of https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00

refactor: adopt a VQ + MaskGIT approach

Co-authored-by: Copilot <copilot@github.com>
parent 1eda704986
commit 068940cae0

docs/vqvae-maskgit-design.md (new file, 397 lines)
@@ -0,0 +1,397 @@
# Ginka Map Generation Model Refactor Design Document

## Background and Motivation

### Current State and Problems

The current model uses a **heatmap** as the generation condition that guides the MaskGIT model. In practice we found the following problems:

1. **Model side**: MaskGIT cannot make effective use of the heatmap; it exerts only weak control over the generated output.
2. **User side**: it is hard for users to construct a semantically sensible heatmap by hand, so the interaction cost is high.
3. **Representation side**: a heatmap is a high-dimensional continuous signal that carries redundant information, yet the amount of information that actually guides generation is small.

### Goals

- Keep the MaskGIT generation paradigm, but replace the generation condition with a **low-dimensional discrete latent z**;
- Extract z from real maps with a **VQ-VAE-style encoder**, forming a reusable **codebook**;
- During training, use the codebook to give MaskGIT a diversity-control signal;
- During inference, **sample z at random directly**, requiring no extra input from the user.

---

## Overall Architecture

The system consists of two submodels; they can be trained in separate stages or jointly.

```
Training:

real map ──► [VQ-VAE encoder] ──► z (discrete code sequence)
                                  │
                                  ▼
                             [codebook]
                                  │
                                  ▼
masked map + z ──► [MaskGIT] ──► predict the masked tiles

Inference:

randomly sample z ──► [MaskGIT iterative decoding] ──► generated map
```

---

## Submodel 1: VQ-VAE-Style Encoder

**Size constraint**: the encoder should stay **< 1M parameters**. The earlier heatmap-based MaskGIT already showed that a 4M model sits right at the boundary between under- and overfitting; the overall size will not change much under the new scheme, and the VQ-VAE encoder is an auxiliary module that does not need to be large.

### Responsibility

Encode a complete map (an `[H, W]` integer matrix) into a **discrete z**, where z is an embedding vector (or sequence of vectors) looked up from the codebook.

### Design

The encoder uses a **pure Transformer architecture** throughout: the map size (13×13) is small, so a Transformer adds little compute, and it is better than a CNN at capturing long-range dependencies between tiles.

#### Option A: a single code for the whole map (z is a scalar index)

- The encoder compresses the whole map into one `d_z`-dimensional vector, which is then quantized to a single codebook entry;
- z is an integer index corresponding to one row of the codebook;
- Simplest structure, but limited expressiveness; it struggles to capture a map's local diversity.

```
map [H, W] ──► Transformer encoder ──► [d_z] ──► quantize ──► index (scalar)
                                                                │
                                                                ▼
                                                    codebook[index] → z [d_z]
```

#### Option B: a spatial code sequence (z is a sequence, **selected**)

- The encoder maps the map to an `[L, d_z]` feature sequence (L is the number of codes);
- Each position is quantized independently to one codebook entry;
- z is a combination of `L` indices, which is more expressive;
- At inference time, simply sample L indices at random.

```
map [H, W] ──► Transformer encoder ──► [L, d_z] ──► per-position quantize ──► [L indices]
                                                                                │
                                                                                ▼
                                                              codebook[index_0..L] → z [L, d_z]
```

### VQ-VAE Quantization

Standard vector quantization:

$$z_q = \arg\min_{e_k \in \mathcal{E}} \| z_e - e_k \|_2$$

where $\mathcal{E} = \{e_1, ..., e_K\}$ is the codebook and $K$ is the codebook size.

The **straight-through estimator** is used for backpropagation:

$$\frac{\partial \mathcal{L}}{\partial z_e} \approx \frac{\partial \mathcal{L}}{\partial z_q}$$
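As a minimal sketch of the two formulas above (nearest-neighbor lookup plus the straight-through trick; the full layer, including the losses, lives in `ginka/vqvae/quantize.py` later in this commit, and the tensor names here are illustrative only):

```python
import torch

def quantize_ste(z_e: torch.Tensor, codebook: torch.Tensor):
    """z_e: [N, d_z] encoder outputs; codebook: [K, d_z]."""
    # nearest codebook entry by L2 distance
    dist = torch.cdist(z_e, codebook)  # [N, K]
    indices = dist.argmin(dim=1)       # [N]
    z_q = codebook[indices]            # [N, d_z]
    # straight-through: the forward pass uses z_q, the backward pass
    # copies the gradient from z_q onto z_e unchanged
    z_q_st = z_e + (z_q - z_e).detach()
    return z_q_st, indices
```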
### Loss Design

The total loss has three parts:

$$\mathcal{L}_{VQVAE} = \mathcal{L}_{recon} + \beta \cdot \mathcal{L}_{commit} + \gamma \cdot \mathcal{L}_{uniform}$$

| Loss term | Description |
| --- | --- |
| $\mathcal{L}_{recon}$ | Reconstruction loss; ensures z can recover the map (optional, depending on whether a decoder is added) |
| $\mathcal{L}_{commit}$ | Commitment loss $\|z_e - \text{sg}(z_q)\|^2$, pulling encoder outputs toward their codes |
| $\mathcal{L}_{uniform}$ | **Uniform-usage regularizer** encouraging all codes to be used equally (see below) |

#### Uniform-Usage Regularization

Goal: make all `K` codes be used equally during training, avoiding codebook collapse.

Option 1: **entropy maximization**

$$\mathcal{L}_{uniform} = -H(p) = \sum_{k=1}^{K} p_k \log p_k$$

where $p_k$ is the frequency with which code $k$ is selected in the current batch (estimated with an EMA).
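A sketch of option 1 with the batch frequencies tracked by an EMA as described (names like `usage_ema` and `decay` are illustrative). Note that hard counts carry no gradient; to optimize this term by gradient, the shipped `VectorQuantizer` later in this commit computes $p$ from a softmax over codebook distances instead:

```python
import torch

K, decay = 16, 0.99
usage_ema = torch.full((K,), 1.0 / K)  # EMA estimate of p_k

def entropy_reg(indices: torch.Tensor) -> torch.Tensor:
    """indices: [N] hard code assignments for the current batch."""
    global usage_ema
    batch_freq = torch.bincount(indices, minlength=K).float()
    batch_freq = batch_freq / batch_freq.sum()
    usage_ema = decay * usage_ema + (1 - decay) * batch_freq
    p = usage_ema / usage_ema.sum()
    # minimizing sum(p log p) = -H(p) maximizes usage entropy
    return (p * torch.log(p + 1e-10)).sum()
```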
Option 2: **EMA updates + resets**

- Update the codebook with an exponential moving average (EMA) instead of by gradient;
- Periodically detect codes whose usage falls below a threshold and reset them to random sample points from the current batch.

Option 3: **KL divergence on codebook usage**

$$\mathcal{L}_{uniform} = \text{KL}(p \| \mathcal{U}_{K})$$

where $\mathcal{U}_K$ is the uniform distribution, i.e. each code is expected to be selected with probability $1/K$.

**Initial choice**: use **entropy maximization** (option 1), optimized directly by gradient, which is simple to implement; if codebook collapse shows up later, add the EMA update + reset mechanism.

### Do We Need a Decoder?

| | With decoder | Without decoder |
| --- | --- | --- |
| Reconstruction loss | Yes; z must recover the map | No; the quality of z is constrained indirectly by MaskGIT's generation quality |
| Training complexity | Higher | Lower |
| Semantics of z | Strong; z explicitly encodes map structure | Weak; z acts more like a style/diversity control |
| Recommended when | z must carry structural semantics | z is only used for diversity control |

**Confirmed: no decoder.** Train jointly with MaskGIT; the semantics of z are determined by gradients backpropagated from the MaskGIT generation loss. z is positioned as a style and diversity control signal, not a structural reconstruction guide — if z could reconstruct the map completely, the model would ignore the other conditions and rely on z alone, which would actually reduce controllability and generalization.

Accordingly, the loss simplifies to:

$$\mathcal{L} = \mathcal{L}_{CE}(\text{MaskGIT}) + \beta \cdot \mathcal{L}_{commit} + \gamma \cdot \mathcal{L}_{uniform}$$

---

## Submodel 2: MaskGIT Rework

### Current Input

```python
def forward(self, map: torch.Tensor, heatmap: torch.Tensor):
    # map: [B, H * W]
    # heatmap: [B, C, H, W]
```

### New Input

```python
def forward(self, map: torch.Tensor, z: torch.Tensor):
    # map: [B, H * W] — masked map token sequence
    # z: [B, L, d_z] — discrete latent (option B) or [B, d_z] (option A)
```

Remove the existing `GinkaMaskGITCond` (the heatmap CNN encoder) and `gate_encoder`; z replaces their role.

### How to Inject z into MaskGIT (candidates)

#### Method 1: Cross-Attention (recommended)

A Transformer decoder supports cross-attention natively:

```
map token sequence ──► self-attention ──► cross-attention ◄── z
                                                 │
                                                 ▼
                                          predicted logits
```

z is fed to the `TransformerDecoder` as memory, and the map token sequence acts as the query.

Advantage: z can be a sequence of arbitrary length (friendly to option B) and is highly expressive. **This method has been selected.**
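A minimal sketch of this wiring with `nn.TransformerDecoder` (the repo's own `Transformer` class, patched later in this commit, reuses its encoder output as the query; the dimensions here are illustrative):

```python
import torch
import torch.nn as nn

d_model = 192
layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=8, batch_first=True)
decoder = nn.TransformerDecoder(layer, num_layers=4)

tokens = torch.randn(2, 169, d_model)    # map token sequence (query)
z_mem = torch.randn(2, 2, d_model)       # projected z (key/value memory)
out = decoder(tgt=tokens, memory=z_mem)  # tokens cross-attend over z
print(out.shape)  # torch.Size([2, 169, 192])
```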
#### Method 2: Prepend z to the Sequence (Prefix Tokens)

Prepend each code embedding of z to the map token sequence as prefix tokens:

```
[z_0, z_1, ..., z_L, tile_0, tile_1, ..., tile_{H*W}]
```

Feed the result into a TransformerEncoder for self-attention; the positional encoding must be extended accordingly.

Advantage: simple to implement, no encoder-decoder split needed. A sketch follows.
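For comparison, a sketch of this prefix-token variant (not the selected method; all names here are illustrative):

```python
import torch
import torch.nn as nn

d_model, L, S = 192, 2, 169
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=d_model, nhead=8, batch_first=True),
    num_layers=4,
)
# positional encoding must cover the L prefix slots as well
pos = nn.Parameter(torch.randn(1, L + S, d_model) * 0.02)

z_tokens = torch.randn(2, L, d_model)    # projected z codes
map_tokens = torch.randn(2, S, d_model)  # embedded map tiles
x = torch.cat([z_tokens, map_tokens], dim=1) + pos
out = encoder(x)[:, L:]  # drop the prefix, keep per-tile features
print(out.shape)  # torch.Size([2, 169, 192])
```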
#### Method 3: Add/FiLM Modulation

Project z, then modulate each layer's features with FiLM (Feature-wise Linear Modulation):

$$y = \gamma(z) \odot x + \beta(z)$$

Advantage: parameter-efficient. Drawback: z must first be pooled into a single vector (suits option A).
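A sketch of FiLM conditioning as described, with z pooled to one vector (module and variable names are illustrative):

```python
import torch
import torch.nn as nn

class FiLM(nn.Module):
    def __init__(self, d_z: int, d_model: int):
        super().__init__()
        # predict per-channel scale (gamma) and shift (beta) from z
        self.to_gamma_beta = nn.Linear(d_z, 2 * d_model)

    def forward(self, x: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        # x: [B, S, d_model]; z: [B, d_z] (pooled to a single vector)
        gamma, beta = self.to_gamma_beta(z).chunk(2, dim=-1)
        return gamma.unsqueeze(1) * x + beta.unsqueeze(1)

film = FiLM(d_z=64, d_model=192)
y = film(torch.randn(2, 169, 192), torch.randn(2, 64))
print(y.shape)  # torch.Size([2, 169, 192])
```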
---

## User Scenarios

User research shows two main intents:

| Scenario | Description | Training subset |
| --- | --- | --- |
| **Fully random generation** | The user provides no condition; generate a complete map directly | Subset C (random walls + mask) |
| **Wall-assisted generation** | The user hand-draws the wall structure; the model fills in the non-wall content | Subsets B / D |

The training strategy must model both scenarios explicitly; mixing multiple subsets lets a single model support both modes of operation.

---

## Training Strategy

### Mixed Multi-Subset Training

Each batch is sampled from four subsets in fixed proportions, each targeting a different capability:

#### Subset A: standard MaskGIT training (ratio 0.4 ~ 0.6)

Keep the original MaskGIT masking paradigm: randomly mask some tiles and train the model's basic inpainting ability.

```
original map ──► randomly mask some tiles ──► MaskGIT predicts the masked tiles
```

- Masking strategy: reuse the existing `MapMask` (random masks + block masks + morphological transforms)
- Purpose: preserve baseline generation quality and keep the targeted subsets from causing regressions

#### Subset B: wall-conditioned generation (ratio 0.1 ~ 0.3)

Remove all non-wall tiles (keeping walls and floor) and have the model generate the remaining content given the wall structure.

```
original map ──► clear all non-wall tiles (keep floor/walls) ──► MaskGIT generates the non-wall content
```

- Input: a map containing only walls (tile=1) and floor (tile=0)
- Target: the complete map (monsters, items, doors, keys, etc.)
- Scenario: **user draws walls → model fills the rest**

#### Subset C: random-condition generation (ratio 0.1 ~ 0.3)

On top of subset B, additionally mask walls at random so only some of them remain as the initial condition, injecting the randomness needed by the fully random scenario.

```
original map ──► clear non-wall tiles ──► randomly mask some wall tiles ──► MaskGIT generates
```

- Input: a sparse, randomly retained fragment of the walls
- Target: the complete map (wall completion + non-wall content)
- Scenario: **fully random generation** (the walls themselves must also be generated)
- Note: randomly sampled z plays its key role here, controlling stylistic diversity

#### Subset D: entrance-conditioned generation (ratio 0.1 ~ 0.2)

On top of subsets B / C, also keep the entrance tile (tile=10), training the model to adapt its output to a given entrance position.

```
original map ──► clear non-wall, non-entrance tiles (keep floor/walls/entrance) ──► MaskGIT generates
```

- Input: wall structure + entrance position
- Target: a complete map semantically consistent with the entrance position
- Scenario: **user pins the entrance → model generates a matching layout**
- Note: this subset exists because the entrance strongly constrains the map's overall topology and needs dedicated reinforcement

### Subset Ratio Summary

| Subset | Training task | Suggested ratio |
| --- | --- | --- |
| A | Standard MaskGIT (random-mask inpainting) | 0.4 ~ 0.6 |
| B | Wall-conditioned generation (all walls kept) | 0.1 ~ 0.3 |
| C | Random-condition generation (some walls kept at random) | 0.1 ~ 0.3 |
| D | Entrance-conditioned generation (walls + entrance kept) | 0.1 ~ 0.2 |

The ratios sum to 1 and can be used directly as sampling weights in `dataset.py`; see the usage sketch below.
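Instantiating the dataset added in this commit with those weights (paths are the repo defaults, assuming the repo root is on `PYTHONPATH`):

```python
from torch.utils.data import DataLoader
from ginka.dataset import GinkaVQDataset

# (A, B, C, D) sampling weights; normalized internally
ds = GinkaVQDataset("data/ginka-dataset.json", subset_weights=(0.5, 0.2, 0.2, 0.1))
loader = DataLoader(ds, batch_size=64, shuffle=True)
batch = next(iter(loader))
print(batch["masked_map"].shape)  # torch.Size([64, 169])
```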
### Joint Training Flow

Train the VQ-VAE encoder together with the reworked MaskGIT:

```
real map ──► VQ-VAE encoder ──► z (discrete) ──► VQ loss (commit + uniform)
                                │
                                ▼
masked map (built per subset) + z ──► MaskGIT ──► predicted logits ──► cross-entropy loss
```

Total loss:

$$\mathcal{L} = \mathcal{L}_{CE}(\text{MaskGIT}) + \beta \cdot \mathcal{L}_{commit} + \gamma \cdot \mathcal{L}_{uniform}$$

### z Dropout (robustness)

During training, replace z with randomly sampled codes at some probability (e.g. 10–20%) to simulate the randomly sampled z used at inference and keep the model from over-relying on the exact semantics of z. This matters most for subset C (fully random generation).

---

## Inference

### Scenario 1: fully random generation

```
1. Uniformly sample L indices from [0, K)
2. Look them up in the codebook to get z [L, d_z]
3. Initialize the map sequence: place a few random walls (ratio 0.02 ~ 0.1) as the initial condition and fill every other position with the MASK token
4. Iterate MaskGIT decoding (cosine schedule):
   a. Feed the masked map + z → predict logits for all MASK positions
   b. Unmask the N tokens with the highest confidence
   c. Repeat until no MASK token remains
5. Output the final map
```

**On the initial wall ratio**: starting from an all-MASK map yields poor diversity (even top-k sampling does not help much). Seeding a few random walls (2%–10%) improves diversity markedly; but the ratio must stay low, or the randomly placed walls can create topological conflicts and hurt generation quality.
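To make step 4 concrete, here is the keep-masked ratio per step under the cosine schedule used by `maskgit_generate` in `ginka/train_vq.py` (12 steps):

```python
import math

steps = 12
for step in range(steps):
    # fraction of still-generatable positions kept masked after this step
    ratio = math.cos(((step + 1) / steps) * math.pi / 2)
    print(f"step {step + 1:2d}: keep masked {ratio:.3f}")
# the ratio falls from ~0.991 after step 1 to 0.0 at step 12,
# so few tokens are committed early and most are committed near the end
```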
### Scenario 2: wall-assisted generation (user-drawn walls)

```
1. The user provides the wall layout (positions with tile=1); every other position is filled with the MASK token
2. Uniformly sample L indices from [0, K) to get z
3. Iterate MaskGIT decoding:
   a. Feed the masked map with known walls + z → predict logits for the MASK positions
   b. Known wall positions never take part in unmasking and stay fixed
   c. Unmask the N MASK tokens with the highest confidence
   d. Repeat until no MASK token remains
4. Output the final map
```

### Scenario 3: entrance-pinned generation

```
Like scenario 2, but the initial condition fixes both the wall positions and the entrance position (tile=10);
the remaining MASK positions are filled by iterative MaskGIT decoding
```

---

## Hyperparameters (to be tuned)

| Parameter | Description | Suggested initial value |
| --- | --- | --- |
| `K` | Codebook size | 8 ~ 32 |
| `L` | Code sequence length (option B) | 1 ~ 4 |
| `d_z` | Code embedding dimension | 64 ~ 128 |
| `β` | Commit loss weight | 0.25 |
| `γ` | Uniform loss weight | 0.1 |
| z dropout probability | Fraction of training steps with randomized z | 0.1 ~ 0.2 |

The map is fixed at 13×13 and z serves as a style/diversity control, so K need not be large (8–32 suffices) and L should also stay small to avoid over-constraining generation.

---

## Decisions Made

| Decision point | Conclusion |
| --- | --- |
| VQ-VAE decoder | **None**; train jointly with MaskGIT, letting the generation loss backpropagation define the semantics of z |
| Encoder architecture | **Option B** (code sequence), pure Transformer; fall back to option A if training is unstable |
| z injection | **Cross-attention** (method 1) |
| Uniform regularization | **Entropy maximization**; add EMA if collapse appears |
| Hierarchical VQ | **Single level**; a 13×13 map needs no multi-level quantization |
| Keeping the scalar cond | **Not for now**; the scalar distribution in the training set is too narrow, making sensible values hard to sample at inference; may be added later |

## Open Questions

- Suitable values for K and L (suggest starting from K=16, L=2)
- The optimal z dropout probability
- If codebook collapse appears later: add the EMA update + reset mechanism
- If finer-grained control is needed later: add the scalar cond (with random perturbation of the scalar on the inference side)

---

## Next Steps

- [x] Decide the encoding scheme (option B, code sequence)
- [x] Decide how to inject z (cross-attention)
- [x] Decide whether a VQ-VAE decoder is needed (no)
- [x] Decide whether to keep the scalar cond (not for now)
- [x] Decide the training subsets and per-scenario ratios
- [x] Implement the VQ-VAE encoder module (Transformer + VQ)
- [ ] Rework `GinkaMaskGIT.forward()` to accept z, replacing the heatmap branch
- [ ] Implement the four-subset sampling logic (new multi-subset Dataset or sampling weights in `dataset.py`)
- [ ] Implement the input builders for subsets B/C/D (clear tiles by rule, keep walls/entrance)
- [ ] Write the joint training script combining the VQ losses with MaskGIT's cross-entropy loss
ginka/dataset.py (202 lines changed)
@@ -229,4 +229,204 @@ class GinkaJointDataset(Dataset):
            "target_map": target_map,
            "target_heatmap": target_heatmap,
            "cond_heatmap": cond_heatmap
        }
    }


class GinkaVQDataset(Dataset):
    """
    Multi-subset dataset for joint VQ-VAE + MaskGIT training.

    Each __getitem__ picks one of four subsets at random, by weight:
        A (standard):    standard MaskGIT masking; randomly mask some tiles
        B (wall-only):   keep only wall(1) + floor(0); replace everything else with MASK(15)
        C (wall-random): on top of B, additionally mask some wall tiles at random
        D (wall+entry):  keep only wall(1) + floor(0) + entrance(10); replace everything else with MASK(15)

    Returns a dict:
        raw_map:    LongTensor [H*W]  full original map (for the VQ-VAE encoder)
        masked_map: LongTensor [H*W]  MaskGIT input (masked positions = 15)
        target_map: LongTensor [H*W]  CE-loss ground truth (same as raw_map)
        subset:     str               subset tag, for debugging/statistics
    """

    FLOOR = 0
    WALL = 1
    ENTRANCE = 10
    MASK_ID = 15

    def __init__(
        self,
        data_path: str,
        subset_weights: tuple = (0.5, 0.2, 0.2, 0.1),
        wall_mask_ratio: float = 0.3,
    ):
        """
        Args:
            data_path: path to the JSON data file
            subset_weights: sampling weights for subsets (A, B, C, D); normalized automatically
            wall_mask_ratio: upper bound on the fraction of wall tiles additionally masked in subset C
                             (the actual fraction is drawn uniformly from [0, wall_mask_ratio])
        """
        self.data = load_data(data_path)
        self.wall_mask_ratio = wall_mask_ratio

        # cumulative weights for fast random subset selection
        total_w = sum(subset_weights)
        normalized = [x / total_w for x in subset_weights]
        self.subset_cumw = [sum(normalized[:i + 1]) for i in range(len(normalized))]

    def __len__(self):
        return len(self.data)

    # ------------------------------------------------------------------
    # Inline random-mask generation (avoids scipy's NumPy version issues)
    # ------------------------------------------------------------------
    @staticmethod
    def _sample_mask_ratio(min_r=0.05, max_r=1.0) -> float:
        """Sample a mask ratio from Beta(2,2), concentrated around the middle."""
        r = np.random.beta(2, 2)
        return min_r + (max_r - min_r) * r

    @staticmethod
    def _random_mask(h: int, w: int) -> np.ndarray:
        """Purely random mask; returns [H*W] bool."""
        ratio = GinkaVQDataset._sample_mask_ratio()
        total = h * w
        idx = np.random.choice(total, int(total * ratio), replace=False)
        mask = np.zeros(total, dtype=bool)
        mask[idx] = True
        return mask

    @staticmethod
    def _block_mask(h: int, w: int) -> np.ndarray:
        """Random rectangular block mask; returns [H*W] bool."""
        ratio = GinkaVQDataset._sample_mask_ratio()
        max_block = max(2, min(h, w) // 2)
        target = int(h * w * ratio)
        mask = np.zeros((h, w), dtype=bool)
        while mask.sum() < target:
            bh = np.random.randint(2, max_block + 1)
            bw = np.random.randint(2, max_block + 1)
            x = np.random.randint(0, max(1, h - bh + 1))
            y = np.random.randint(0, max(1, w - bw + 1))
            mask[x:x + bh, y:y + bw] = True
        return mask.reshape(-1)

    def _std_mask(self, h: int, w: int) -> np.ndarray:
        """Standard MaskGIT mask: randomly pick the random or block strategy."""
        if random.random() < 0.5:
            return self._random_mask(h, w)
        else:
            return self._block_mask(h, w)

    # ------------------------------------------------------------------

    def _augment(self, arr: np.ndarray) -> np.ndarray:
        """Random rotation / flip augmentation; returns a new array."""
        if np.random.rand() > 0.5:
            k = np.random.randint(1, 4)
            arr = np.rot90(arr, k).copy()
        if np.random.rand() > 0.5:
            arr = np.fliplr(arr).copy()
        if np.random.rand() > 0.5:
            arr = np.flipud(arr).copy()
        return arr

    def _choose_subset(self) -> str:
        r = random.random()
        if r < self.subset_cumw[0]:
            return 'A'
        elif r < self.subset_cumw[1]:
            return 'B'
        elif r < self.subset_cumw[2]:
            return 'C'
        else:
            return 'D'

    def _apply_subset(self, raw: np.ndarray, subset: str) -> np.ndarray:
        """
        Build masked_map according to the subset type.

        Args:
            raw: [H, W] int64 full original map
            subset: 'A' | 'B' | 'C' | 'D'

        Returns:
            [H*W] int64, with masked positions set to MASK_ID (15)
        """
        H, W = raw.shape

        if subset == 'A':
            # standard random mask: random or block strategy
            mask = self._std_mask(H, W)  # [H*W] bool
            flat = raw.reshape(-1).copy()
            flat[mask] = self.MASK_ID
            return flat

        elif subset == 'B':
            # keep only floor(0) and wall(1)
            flat = raw.reshape(-1).copy()
            keep = (flat == self.FLOOR) | (flat == self.WALL)
            flat[~keep] = self.MASK_ID
            return flat

        elif subset == 'C':
            # subset B + randomly mask some walls
            flat = raw.reshape(-1).copy()
            keep = (flat == self.FLOOR) | (flat == self.WALL)
            flat[~keep] = self.MASK_ID

            wall_idx = np.where(flat == self.WALL)[0]
            if len(wall_idx) > 0:
                ratio = random.random() * self.wall_mask_ratio
                n = max(1, int(len(wall_idx) * ratio))
                chosen = np.random.choice(wall_idx, n, replace=False)
                flat[chosen] = self.MASK_ID
            return flat

        else:  # D
            # keep only floor(0), wall(1) and entrance(10)
            flat = raw.reshape(-1).copy()
            keep = (
                (flat == self.FLOOR)
                | (flat == self.WALL)
                | (flat == self.ENTRANCE)
            )
            flat[~keep] = self.MASK_ID
            return flat

    def __getitem__(self, idx):
        item = self.data[idx]

        raw_np = self._augment(np.array(item['map'], dtype=np.int64))  # [H, W]
        subset = self._choose_subset()
        masked_np = self._apply_subset(raw_np, subset)  # [H*W]
        raw_flat = raw_np.reshape(-1)  # [H*W]

        return {
            "raw_map": torch.LongTensor(raw_flat),           # VQ-VAE encoder input
            "masked_map": torch.LongTensor(masked_np),       # MaskGIT input
            "target_map": torch.LongTensor(raw_flat.copy()), # CE-loss ground truth
            "subset": subset,                                # debugging/statistics
        }


if __name__ == "__main__":
    import os
    data_path = os.path.join(os.path.dirname(__file__), '..', 'data', 'ginka-dataset.json')
    ds = GinkaVQDataset(data_path)
    print(f"dataset size: {len(ds)}")

    subset_count = {'A': 0, 'B': 0, 'C': 0, 'D': 0}
    for i in range(200):
        sample = ds[i % len(ds)]
        subset_count[sample['subset']] += 1

    raw = sample['raw_map']
    masked = sample['masked_map']
    target = sample['target_map']
    print(f"raw_map shape={raw.shape}, dtype={raw.dtype}")
    print(f"masked_map shape={masked.shape}, dtype={masked.dtype}")
    print(f"target_map shape={target.shape}, dtype={target.dtype}")
    print(f"masked positions: {(masked == 15).sum().item()} / {masked.numel()}")
    print(f"\nsubset distribution over 200 draws: {subset_count}")
ginka/maskGIT/maskGIT.py
@@ -15,9 +15,15 @@ class Transformer(nn.Module):
             num_layers=num_layers
         )

-    def forward(self, x):
-        # x: [B, L, d_model]
-        m = self.encoder(x)
-        out = self.decoder(x, m)
+    def forward(self, x, memory=None):
+        # x: [B, S, d_model] map token sequence
+        # memory: [B, L, d_model] optional projected z for cross-attention
+        # if memory is None, fall back to the original self-encode/decode behavior (backward compatible)
+        enc_out = self.encoder(x)
+        if memory is not None:
+            # the encoder output acts as the query; z acts as key/value
+            out = self.decoder(enc_out, memory)
+        else:
+            out = self.decoder(x, enc_out)
         return out

ginka/maskGIT/model.py
@@ -1,88 +1,117 @@
 import time
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from ..utils import print_memory
-from .cond import GinkaMaskGITCond
 from .maskGIT import Transformer


 class GinkaMaskGIT(nn.Module):
+    """
+    The reworked MaskGIT map generation model.
+
+    Takes a masked map token sequence and the discrete latent z produced by
+    the VQ-VAE as input, and predicts the tile class at every masked position
+    through a Transformer encoder-decoder.
+
+    z is injected into the Transformer decoder via cross-attention and acts
+    as a style/diversity control signal, not a structural reconstruction guide.
+    """

     def __init__(
-        self, num_classes=16, heatmap_channel=4, d_model=256,
-        dim_ff=512, nhead=8, num_layers=4, map_size=13*13
+        self,
+        num_classes: int = 16,
+        d_model: int = 192,
+        d_z: int = 64,
+        dim_ff: int = 512,
+        nhead: int = 8,
+        num_layers: int = 4,
+        map_size: int = 13 * 13,
+        z_dropout: float = 0.1,
     ):
+        """
+        Args:
+            num_classes: number of tile classes (including the MASK token=15)
+            d_model: Transformer internal dimension
+            d_z: VQ-VAE code embedding dimension; must match GinkaVQVAE.d_z
+            dim_ff: feed-forward hidden dimension
+            nhead: number of attention heads
+            num_layers: number of Transformer layers
+            map_size: total number of map tokens (H * W)
+            z_dropout: probability of replacing z with random noise during training (robustness)
+        """
         super().__init__()

+        self.z_dropout = z_dropout
+
+        # tile embedding + positional encoding
         self.tile_embedding = nn.Embedding(num_classes, d_model)
-        self.pos_embedding = nn.Parameter(torch.randn(1, map_size, d_model))
-
-        cond_channels = [d_model // 4, d_model // 2, d_model]
-        self.cond_encoder = GinkaMaskGITCond(input_channel=heatmap_channel, channels=cond_channels)
-        self.gate_encoder = nn.Sequential(
-            nn.Conv2d(cond_channels[2], cond_channels[2], 3, padding=1, padding_mode="replicate"),
-            nn.BatchNorm2d(cond_channels[2]),
-            nn.GELU()
+        self.pos_embedding = nn.Parameter(torch.randn(1, map_size, d_model) * 0.02)
+
+        # z projection: map VQ codes from d_z to d_model for cross-attention
+        self.z_proj = nn.Sequential(
+            nn.Linear(d_z, d_model),
+            nn.LayerNorm(d_model),
         )
-        self.cond_gate = nn.Sequential(
-            nn.Linear(cond_channels[2] * 2, cond_channels[2]),
-            nn.LayerNorm(cond_channels[2]),
-            nn.Dropout(0.3),
-            nn.GELU(),
-
-            nn.Linear(cond_channels[2], cond_channels[2])
+
+        # Transformer: the encoder self-attends over map tokens, the decoder cross-attends to z
+        self.transformer = Transformer(
+            d_model=d_model, dim_ff=dim_ff, nhead=nhead, num_layers=num_layers
         )

-        self.transformer = Transformer(d_model=d_model, dim_ff=dim_ff, nhead=nhead, num_layers=num_layers)
-
-        self.output_fc = nn.Sequential(
-            nn.Linear(d_model, num_classes)
-        )
-
-    def forward(self, map: torch.Tensor, heatmap: torch.Tensor):
-        # map: [B, H * W]
-        # heatmap: [B, C, H, W]
-        # output: [B, H * W, num_classes]
-        heatmap = self.cond_encoder(heatmap)  # [B, d_model, H, W]
-        B, C, H, W = heatmap.shape
-        heatmap_gate = self.gate_encoder(heatmap)
-        heatmap_mean = F.avg_pool2d(heatmap_gate, (H, W))  # [B, d_model, 1, 1]
-        heatmap_max = F.max_pool2d(heatmap_gate, (H, W))  # [B, d_model, 1, 1]
-        gate_input = torch.cat([heatmap_mean, heatmap_max], dim=1).squeeze(2).squeeze(2)
-        gate = self.cond_gate(gate_input)  # [B, d_model]
-
-        heatmap = heatmap * torch.sigmoid(gate).unsqueeze(2).unsqueeze(2)
-        heatmap = heatmap.view(B, C, H * W).permute(0, 2, 1)
-        x = self.tile_embedding(map) + heatmap
-        x = x + self.pos_embedding
-        x = self.transformer(x)
-
-        logits = self.output_fc(x)
+        self.output_fc = nn.Linear(d_model, num_classes)
+
+    def forward(self, map: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            map: [B, H*W] masked map token sequence (MASK token = 15)
+            z: [B, L, d_z] quantized discrete latent from the VQ-VAE
+
+        Returns:
+            logits: [B, H*W, num_classes]
+        """
+        # z dropout: during training, replace z with random noise at some
+        # probability to mimic the randomly sampled z used at inference and
+        # keep the model from overfitting to the exact semantics of z
+        if self.training and self.z_dropout > 0:
+            mask = torch.rand(z.shape[0], 1, 1, device=z.device) < self.z_dropout
+            rand_z = torch.randn_like(z)
+            z = torch.where(mask, rand_z, z)
+
+        # project z to d_model
+        z_mem = self.z_proj(z)  # [B, L, d_model]
+
+        # tile embedding + positional encoding
+        x = self.tile_embedding(map)  # [B, H*W, d_model]
+        x = x + self.pos_embedding    # [B, H*W, d_model]
+
+        # Transformer: encoder self-attention over the map, decoder cross-attends to z
+        x = self.transformer(x, memory=z_mem)  # [B, H*W, d_model]
+
+        logits = self.output_fc(x)  # [B, H*W, num_classes]
         return logits


 if __name__ == "__main__":
     device = torch.device("cpu")

-    map = torch.randint(0, 16, [1, 169]).to(device)
-    heatmap = torch.rand(1, 4, 13, 13).to(device)
-
-    # initialize the model
-    model = GinkaMaskGIT().to(device)
-
-    print_memory("after init")
-
-    # forward pass
-    start = time.perf_counter()
-    output = model(map, heatmap)
-    end = time.perf_counter()
-
-    print_memory("after forward")
-
-    print(f"inference time: {end - start}")
-    print(f"output shape: output={output.shape}")
-    print(f"Tile Embedding parameters: {sum(p.numel() for p in model.tile_embedding.parameters())}")
-    print(f"Condition Encoder parameters: {sum(p.numel() for p in model.cond_encoder.parameters())}")
-    print(f"Condition Gate parameters: {sum(p.numel() for p in model.cond_gate.parameters())}")
-    print(f"MaskGIT parameters: {sum(p.numel() for p in model.transformer.parameters())}")
-    print(f"Output parameters: {sum(p.numel() for p in model.output_fc.parameters())}")
-    print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")
+    model = GinkaMaskGIT(
+        num_classes=16,
+        d_model=192,
+        d_z=64,
+        dim_ff=512,
+        nhead=8,
+        num_layers=4,
+        map_size=13 * 13,
+    ).to(device)
+
+    total_params = sum(p.numel() for p in model.parameters())
+    print(f"total parameters: {total_params:,} ({total_params / 1e6:.3f}M)")
+    for name, module in model.named_children():
+        n = sum(p.numel() for p in module.parameters())
+        print(f"  {name}: {n:,}")
+
+    map_input = torch.randint(0, 16, (4, 13 * 13)).to(device)  # [B=4, 169]
+    z_input = torch.randn(4, 2, 64).to(device)                 # [B=4, L=2, d_z=64]
+
+    model.train()
+    logits = model(map_input, z_input)
+    print(f"\nlogits shape: {logits.shape}")  # [4, 169, 16]
+
+    print_memory(device, "after forward")
ginka/train_vq.py (new file, 520 lines)
@@ -0,0 +1,520 @@
"""
Joint training script: VQ-VAE + MaskGIT

Total loss = L_CE (MaskGIT reconstruction loss) + beta * L_commit + gamma * L_entropy

During validation, images are written for each of the four subsets (A/B/C/D);
every sample additionally gets N_Z_SAMPLES randomly sampled z values,
making it easy to compare generations for the same condition under different z.

Usage:
    python -m ginka.train_vq
    python -m ginka.train_vq --resume True --state result/joint/joint-10.pth
"""

import argparse
import math
import os
import sys
import random
from datetime import datetime

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
from torch.utils.data import DataLoader

from .vqvae.model import GinkaVQVAE
from .maskGIT.model import GinkaMaskGIT
from .dataset import GinkaVQDataset
from shared.image import matrix_to_image_cv

# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------
BATCH_SIZE = 64
NUM_CLASSES = 16
MASK_TOKEN = 15
GENERATE_STEP = 12   # MaskGIT decoding steps at inference
MAP_SIZE = 13 * 13
MAP_H = MAP_W = 13
LABEL_SMOOTHING = 0.0

# VQ-VAE hyperparameters
VQ_L = 2        # number of summary tokens (i.e. the sequence length of z)
VQ_K = 16       # codebook size
VQ_D_Z = 64     # codebook embedding dimension
VQ_D_MODEL = 128
VQ_NHEAD = 4
VQ_LAYERS = 2
VQ_DIM_FF = 256
VQ_BETA = 0.25  # commit loss weight
VQ_GAMMA = 0.1  # entropy loss weight

# MaskGIT hyperparameters
MG_D_MODEL = 192
MG_NHEAD = 8
MG_LAYERS = 4
MG_DIM_FF = 512
MG_Z_DROPOUT = 0.15  # probability of replacing z with random noise during training

# number of extra random z samples per validation sample (0 = only the real z)
N_Z_SAMPLES = 3

# sampling weights of the four subsets A/B/C/D (shared by train and val sets)
SUBSET_WEIGHTS = (0.5, 0.2, 0.2, 0.1)

# ---------------------------------------------------------------------------
# Device
# ---------------------------------------------------------------------------
device = torch.device(
    "cuda:1" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

os.makedirs("result/joint", exist_ok=True)
os.makedirs("result/joint_img", exist_ok=True)

disable_tqdm = not sys.stdout.isatty()
# ---------------------------------------------------------------------------
# Argument parsing
# ---------------------------------------------------------------------------
def parse_arguments():
    parser = argparse.ArgumentParser(description="joint VQ-VAE + MaskGIT training")
    # note: argparse's type=bool treats any non-empty string (including
    # "False") as True, so parse the string explicitly instead
    parser.add_argument("--resume", type=lambda s: s.lower() == "true", default=False)
    parser.add_argument("--state", type=str, default="result/joint/joint-10.pth",
                        help="checkpoint path to load when resuming")
    parser.add_argument("--train", type=str, default="data/ginka-dataset.json")
    parser.add_argument("--validate", type=str, default="data/ginka-eval.json")
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--checkpoint", type=int, default=5,
                        help="save a checkpoint and validate every this many epochs")
    parser.add_argument("--load_optim", type=lambda s: s.lower() == "true", default=True)
    return parser.parse_args()
# ---------------------------------------------------------------------------
# MaskGIT inference (cosine-schedule iterative decoding)
# ---------------------------------------------------------------------------
@torch.no_grad()
def maskgit_generate(
    model_mg: GinkaMaskGIT,
    z: torch.Tensor,
    steps: int = GENERATE_STEP,
    init_map: torch.Tensor = None,
) -> torch.Tensor:
    """
    Generate a map iteratively (cosine-schedule unmasking).

    Args:
        model_mg: GinkaMaskGIT (in eval mode)
        z: [B, L, d_z] condition z
        steps: number of decoding steps
        init_map: [B, MAP_SIZE] optional initial map; non-MASK positions stay
                  fixed during generation. If None, start from all MASK
                  (free generation).

    Returns:
        map_out: [B, MAP_SIZE]
    """
    B = z.shape[0]
    if init_map is None:
        map_seq = torch.full((B, MAP_SIZE), MASK_TOKEN, device=device)
    else:
        map_seq = init_map.clone().to(device)

    # remember the initial MASK positions; only these need to be generated
    generatable = (map_seq == MASK_TOKEN)  # [B, S] bool

    for step in range(steps):
        if not generatable.any():
            break

        logits = model_mg(map_seq, z)  # [B, S, C]
        probs = F.softmax(logits, dim=-1)
        dist = torch.distributions.Categorical(probs)
        sampled = dist.sample()  # [B, S]

        # per-position confidence; fixed positions get +inf so they are never
        # selected to "stay MASK"
        confidences = torch.gather(probs, -1, sampled.unsqueeze(-1)).squeeze(-1)
        confidences = confidences.masked_fill(~generatable, float('inf'))

        ratio = math.cos(((step + 1) / steps) * math.pi / 2)
        new_map = map_seq.clone()

        for b in range(B):
            n_gen = int(generatable[b].sum().item())
            n_keep = int(ratio * n_gen)  # positions that stay MASK this step

            if n_keep > 0:
                _, keep_idx = torch.topk(confidences[b], k=n_keep, largest=False)
                pred_b = sampled[b].clone()
                pred_b[keep_idx] = MASK_TOKEN
                new_map[b] = torch.where(generatable[b], pred_b, map_seq[b])
            else:
                new_map[b] = torch.where(generatable[b], sampled[b], map_seq[b])

        map_seq = new_map

    return map_seq
# ---------------------------------------------------------------------------
# Visualization helpers
# ---------------------------------------------------------------------------
def make_map_image(map_flat: torch.Tensor, tile_dict: dict) -> np.ndarray:
    """Turn a [MAP_SIZE] tensor into an RGB image (numpy)."""
    arr = map_flat.cpu().numpy().reshape(MAP_H, MAP_W)
    return matrix_to_image_cv(arr, tile_dict)


def hstack_images(imgs: list, gap: int = 4, color=(255, 255, 255)) -> np.ndarray:
    """Stack equal-height images horizontally with a white divider between them."""
    H = imgs[0].shape[0]
    vline = np.full((H, gap, 3), color, dtype=np.uint8)
    result = imgs[0]
    for img in imgs[1:]:
        result = np.concatenate([result, vline, img], axis=1)
    return result


def label_image(img: np.ndarray, text: str, font_scale: float = 0.45) -> np.ndarray:
    """Add a one-line text label above the image (returns the combined image)."""
    bar_h = 16
    bar = np.full((bar_h, img.shape[1], 3), (40, 40, 40), dtype=np.uint8)
    cv2.putText(
        bar, text, (2, bar_h - 3),
        cv2.FONT_HERSHEY_SIMPLEX, font_scale,
        (200, 200, 200), 1, cv2.LINE_AA
    )
    return np.concatenate([bar, img], axis=0)


def make_random_wall_seed(ratio_min: float = 0.02, ratio_max: float = 0.08) -> torch.Tensor:
    """
    Place a few random walls on an all-MASK map as an inference seed,
    used for the fully random generation scenario.

    Returns:
        [1, MAP_SIZE] MASK=15 background + a few random walls (tile=1)
    """
    ratio = random.uniform(ratio_min, ratio_max)
    n_wall = max(2, int(MAP_SIZE * ratio))
    seed = torch.full((1, MAP_SIZE), MASK_TOKEN, dtype=torch.long, device=device)
    idx = torch.randperm(MAP_SIZE)[:n_wall]
    seed[0, idx] = 1  # wall
    return seed
@torch.no_grad()
def validate(
    model_vq: GinkaVQVAE,
    model_mg: GinkaMaskGIT,
    dataloader_val: DataLoader,
    tile_dict: dict,
    epoch: int,
):
    """
    Validation: compute the val loss and write comparison images for 5 inference scenes.

    Scenes (a subfolder per epoch keeps images from piling up):
        scene 1 (scene1_completion): subset A, standard random-mask completion
            columns: ground truth | masked input | z_real pred | z_real gen | z_rand×N
        scene 2 (scene2_wall): subset B, walls+floor only → generate the full map
            columns: ground truth | wall-only input | z_real gen | z_rand×N
        scene 3 (scene3_sparse): subset C, sparse wall condition → generate the full map
            columns: ground truth | sparse wall input | z_real gen | z_rand×N
        scene 4 (scene4_entrance): subset D, walls+entrance → generate the full map
            columns: ground truth | wall+entrance input | z_real gen | z_rand×N
        scene 5 (scene5_random): no dataset reference; random sparse wall seed → fully random generation
            columns: random seed | z_rand×(N+1)
    """
    model_vq.eval()
    model_mg.eval()

    # a separate subfolder per epoch so every validation run is kept for review
    epoch_dir = f"result/joint_img/e{epoch:04d}"
    os.makedirs(epoch_dir, exist_ok=True)

    val_loss_total = 0.0
    val_steps = 0
    captured = {s: None for s in ('A', 'B', 'C', 'D')}

    # ── compute the val loss and capture one sample per subset ──────────────
    for batch in tqdm(dataloader_val, desc="Validating", leave=False, disable=disable_tqdm):
        raw_map = batch["raw_map"].to(device)        # [B, 169]
        masked_map = batch["masked_map"].to(device)  # [B, 169]
        target_map = batch["target_map"].to(device)  # [B, 169]
        subsets = batch["subset"]                    # list of str
        B = raw_map.shape[0]

        z_q, _, vq_loss = model_vq(raw_map)
        logits = model_mg(masked_map, z_q)
        mask = (masked_map == MASK_TOKEN)

        ce_loss = F.cross_entropy(
            logits.permute(0, 2, 1), target_map,
            reduction='none', label_smoothing=LABEL_SMOOTHING
        )
        masked_ce = (ce_loss * mask).sum() / (mask.sum() + 1e-6)
        val_loss_total += (masked_ce + vq_loss).item()
        val_steps += 1

        for i in range(B):
            s = subsets[i]
            if captured[s] is None:
                captured[s] = {
                    "raw": raw_map[i:i+1].clone(),
                    "masked": masked_map[i:i+1].clone(),
                    "z_q": z_q[i:i+1].clone(),
                }

        if all(v is not None for v in captured.values()):
            break

    # ── shared helper: sample z at random n times for a given condition map ─
    def _rand_gens(cond_map, n):
        imgs = []
        for i in range(n):
            z_r = model_vq.sample(1, device)
            gen = maskgit_generate(model_mg, z_r, init_map=cond_map)
            imgs.append(label_image(make_map_image(gen[0], tile_dict), f"z_rand_{i + 1}"))
        return imgs

    # ── scene 1: standard mask completion (subset A) ─────────────────────────
    if captured['A'] is not None:
        cap = captured['A']
        raw, cond, z_q = cap['raw'], cap['masked'], cap['z_q']

        real_img = label_image(make_map_image(raw[0], tile_dict), "ground truth")
        cond_img = label_image(make_map_image(cond[0], tile_dict), "masked input")

        # single-step argmax prediction (the model's instantaneous view of the masked positions)
        pred = model_mg(cond, z_q).argmax(dim=-1)[0]
        pred_img = label_image(make_map_image(pred, tile_dict), "z_real pred")

        # iterative generation (from the masked input, with the real z)
        gen_real = maskgit_generate(model_mg, z_q, init_map=cond)
        gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")

        row = [real_img, cond_img, pred_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
        cv2.imwrite(f"{epoch_dir}/scene1_completion.png", hstack_images(row))

    # ── scene 2: wall-assisted generation (subset B) ─────────────────────────
    if captured['B'] is not None:
        cap = captured['B']
        raw, cond, z_q = cap['raw'], cap['masked'], cap['z_q']

        real_img = label_image(make_map_image(raw[0], tile_dict), "ground truth")
        cond_img = label_image(make_map_image(cond[0], tile_dict), "wall-only input")
        gen_real = maskgit_generate(model_mg, z_q, init_map=cond)
        gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")

        row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
        cv2.imwrite(f"{epoch_dir}/scene2_wall.png", hstack_images(row))

    # ── scene 3: sparse wall condition (subset C) ────────────────────────────
    if captured['C'] is not None:
        cap = captured['C']
        raw, cond, z_q = cap['raw'], cap['masked'], cap['z_q']

        real_img = label_image(make_map_image(raw[0], tile_dict), "ground truth")
        cond_img = label_image(make_map_image(cond[0], tile_dict), "sparse wall input")
        gen_real = maskgit_generate(model_mg, z_q, init_map=cond)
        gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")

        row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
        cv2.imwrite(f"{epoch_dir}/scene3_sparse.png", hstack_images(row))

    # ── scene 4: walls + entrance condition (subset D) ───────────────────────
    if captured['D'] is not None:
        cap = captured['D']
        raw, cond, z_q = cap['raw'], cap['masked'], cap['z_q']

        real_img = label_image(make_map_image(raw[0], tile_dict), "ground truth")
        cond_img = label_image(make_map_image(cond[0], tile_dict), "wall+entrance input")
        gen_real = maskgit_generate(model_mg, z_q, init_map=cond)
        gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")

        row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
        cv2.imwrite(f"{epoch_dir}/scene4_entrance.png", hstack_images(row))

    # ── scene 5: fully random generation (no dataset reference) ─────────────
    # a random sparse wall seed, matching the "no condition at all" inference scenario
    rand_seed = make_random_wall_seed()  # [1, MAP_SIZE]
    seed_img = label_image(make_map_image(rand_seed[0], tile_dict), "random seed")
    row = [seed_img] + _rand_gens(rand_seed, N_Z_SAMPLES + 1)  # one extra z to show diversity
    cv2.imwrite(f"{epoch_dir}/scene5_random.png", hstack_images(row))

    avg_val_loss = val_loss_total / max(val_steps, 1)
    return avg_val_loss
# ---------------------------------------------------------------------------
# Main training function
# ---------------------------------------------------------------------------
def train():
    print(f"Using device: {device}")
    args = parse_arguments()

    # ---- models ----
    model_vq = GinkaVQVAE(
        num_classes=NUM_CLASSES,
        L=VQ_L, K=VQ_K, d_z=VQ_D_Z,
        d_model=VQ_D_MODEL, nhead=VQ_NHEAD,
        num_layers=VQ_LAYERS, dim_ff=VQ_DIM_FF,
        map_size=MAP_SIZE,
        beta=VQ_BETA, gamma=VQ_GAMMA,
    ).to(device)

    model_mg = GinkaMaskGIT(
        num_classes=NUM_CLASSES,
        d_model=MG_D_MODEL, d_z=VQ_D_Z,
        dim_ff=MG_DIM_FF, nhead=MG_NHEAD,
        num_layers=MG_LAYERS,
        map_size=MAP_SIZE,
        z_dropout=MG_Z_DROPOUT,
    ).to(device)

    vq_params = sum(p.numel() for p in model_vq.parameters())
    mg_params = sum(p.numel() for p in model_mg.parameters())
    print(f"VQ-VAE parameters: {vq_params:,} ({vq_params/1e6:.3f}M)")
    print(f"MaskGIT parameters: {mg_params:,} ({mg_params/1e6:.3f}M)")
    print(f"Total parameters: {vq_params+mg_params:,} ({(vq_params+mg_params)/1e6:.3f}M)")

    # ---- datasets ----
    dataset_train = GinkaVQDataset(
        args.train,
        subset_weights=SUBSET_WEIGHTS,
    )
    dataset_val = GinkaVQDataset(
        args.validate,
        subset_weights=SUBSET_WEIGHTS,
    )
    dataloader_train = DataLoader(
        dataset_train, batch_size=BATCH_SIZE, shuffle=True,
        num_workers=0, pin_memory=(device.type == "cuda"),
    )
    dataloader_val = DataLoader(
        dataset_val, batch_size=8, shuffle=True,
        num_workers=0,
    )

    # ---- optimizer (joint training: both models share one optimizer) ----
    all_params = list(model_vq.parameters()) + list(model_mg.parameters())
    optimizer = optim.AdamW(all_params, lr=2e-4, weight_decay=1e-2)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=args.epochs, eta_min=1e-6
    )

    # ---- resume ----
    start_epoch = 0
    if args.resume:
        ckpt = torch.load(args.state, map_location=device)
        model_vq.load_state_dict(ckpt["vq_state"], strict=False)
        model_mg.load_state_dict(ckpt["mg_state"], strict=False)
        if args.load_optim and ckpt.get("optim_state") is not None:
            optimizer.load_state_dict(ckpt["optim_state"])
        start_epoch = ckpt.get("epoch", 0)
        print(f"Resuming training from epoch {start_epoch}.")

    # ---- tile images (for validation visualization) ----
    tile_dict = {}
    for file in os.listdir("tiles"):
        name = os.path.splitext(file)[0]
        img = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED)
        if img is not None:
            tile_dict[name] = img

    # ---- training loop ----
    for epoch in tqdm(range(start_epoch, start_epoch + args.epochs),
                      desc="Joint Training", disable=disable_tqdm):
        model_vq.train()
        model_mg.train()

        loss_total = 0.0
        ce_total = 0.0
        vq_loss_total = 0.0
        subset_stats = {'A': 0, 'B': 0, 'C': 0, 'D': 0}

        for batch in tqdm(dataloader_train, leave=False,
                          desc="Epoch Progress", disable=disable_tqdm):
            raw_map = batch["raw_map"].to(device)        # [B, 169]
            masked_map = batch["masked_map"].to(device)  # [B, 169]
            target_map = batch["target_map"].to(device)  # [B, 169]

            for s in batch["subset"]:
                subset_stats[s] = subset_stats.get(s, 0) + 1

            # ---- forward pass ----
            # 1. VQ-VAE encodes the real map → z_q
            z_q, _, vq_loss = model_vq(raw_map)  # z_q: [B, L, d_z]

            # 2. MaskGIT predicts the original tiles from the masked map + z
            logits = model_mg(masked_map, z_q)  # [B, 169, C]

            # 3. compute CE loss on the masked positions only
            mask = (masked_map == MASK_TOKEN)  # [B, 169] bool
            ce_loss = F.cross_entropy(
                logits.permute(0, 2, 1), target_map,
                reduction='none', label_smoothing=LABEL_SMOOTHING
            )
            masked_ce = (ce_loss * mask).sum() / (mask.sum() + 1e-6)

            # 4. joint loss
            loss = masked_ce + vq_loss

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(all_params, max_norm=1.0)
            optimizer.step()

            loss_total += loss.detach().item()
            ce_total += masked_ce.detach().item()
            vq_loss_total += vq_loss.detach().item()

        scheduler.step()

        n = len(dataloader_train)
        tqdm.write(
            f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
            f"Epoch {epoch + 1:4d} | "
            f"Loss {loss_total/n:.5f} "
            f"CE {ce_total/n:.5f} "
            f"VQ {vq_loss_total/n:.5f} | "
            f"LR {scheduler.get_last_lr()[0]:.6f} | "
            f"Subsets {subset_stats}"
        )

        # ---- checkpoint + validation ----
        if (epoch + 1) % args.checkpoint == 0:
            ckpt_path = f"result/joint/joint-{epoch + 1}.pth"
            torch.save({
                "epoch": epoch + 1,
                "vq_state": model_vq.state_dict(),
                "mg_state": model_mg.state_dict(),
                "optim_state": optimizer.state_dict(),
            }, ckpt_path)
            tqdm.write(f"  checkpoint saved: {ckpt_path}")

            val_loss = validate(
                model_vq, model_mg, dataloader_val, tile_dict, epoch + 1
            )
            tqdm.write(
                f"[Validate] Epoch {epoch + 1:4d} | Val Loss {val_loss:.5f}"
            )
            # back to training mode
            model_vq.train()
            model_mg.train()

    print("Training finished.")
    torch.save({
        "epoch": start_epoch + args.epochs,
        "vq_state": model_vq.state_dict(),
        "mg_state": model_mg.state_dict(),
    }, "result/joint/joint_final.pth")


# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.set_num_threads(4)
    train()
ginka/vqvae/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from .model import GinkaVQVAE
ginka/vqvae/model.py (new file, 186 lines)
@@ -0,0 +1,186 @@
import torch
import torch.nn as nn
from .quantize import VectorQuantizer
from typing import Tuple


class GinkaVQVAE(nn.Module):
    """
    VQ-VAE-style map encoder.

    Encodes a complete map (a [B, H*W] sequence of integer tile IDs) into L
    discrete codes and outputs z [B, L, d_z] as the generation condition for
    the MaskGIT model.

    Architecture:
        tile embedding + positional encoding
        → L learnable summary tokens (prepended to the sequence)
        → Transformer encoder (Pre-LN, self-attention)
        → take the first L outputs
        → linear projection to d_z
        → VectorQuantizer (straight-through estimator + entropy-maximization regularizer)

    Design constraints:
        - parameter budget < 1M
        - no decoder; the semantics of z are constrained indirectly by MaskGIT's cross-entropy loss
        - z acts as a style/diversity control signal, not a structural reconstruction guide
    """

    def __init__(
        self,
        num_classes: int = 16,
        L: int = 2,
        K: int = 16,
        d_z: int = 64,
        d_model: int = 128,
        nhead: int = 4,
        num_layers: int = 2,
        dim_ff: int = 256,
        map_size: int = 13 * 13,
        beta: float = 0.25,
        gamma: float = 0.1,
        vq_temp: float = 1.0,
    ):
        """
        Args:
            num_classes: number of tile classes (including the MASK token)
            L: code sequence length, i.e. the sequence dimension of z
            K: codebook size (total number of codes)
            d_z: code embedding dimension
            d_model: Transformer internal dimension
            nhead: number of attention heads
            num_layers: number of Transformer layers
            dim_ff: feed-forward hidden dimension
            map_size: total number of map tokens (H * W)
            beta: commitment loss weight
            gamma: entropy regularizer weight
            vq_temp: softmax temperature for the VQ soft assignment
        """
        super().__init__()
        self.L = L
        self.K = K
        self.d_z = d_z
        self.beta = beta
        self.gamma = gamma

        # tile embedding
        self.tile_embedding = nn.Embedding(num_classes, d_model)

        # map positional encoding (covers only the map_size positions, not the summary tokens)
        self.pos_embedding = nn.Parameter(torch.randn(1, map_size, d_model) * 0.02)

        # L learnable summary tokens, prepended to the sequence
        self.summary_tokens = nn.Parameter(torch.randn(1, L, d_model) * 0.02)

        # Pre-LN Transformer encoder (more stable to train)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_ff,
            batch_first=True,
            activation='gelu',
            norm_first=True,  # Pre-LN
            dropout=0.0,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # project the Transformer output to the codebook dimension d_z
        self.proj = nn.Sequential(
            nn.Linear(d_model, d_z),
            nn.LayerNorm(d_z),
        )

        # vector quantization layer
        self.vq = VectorQuantizer(K=K, d_z=d_z, temp=vq_temp)

    def encode(self, map: torch.Tensor) -> torch.Tensor:
        """
        Encode a map into the continuous pre-quantization vector sequence.

        Args:
            map: [B, H*W] integer tile IDs

        Returns:
            z_e: [B, L, d_z] pre-quantization encoding
        """
        B = map.shape[0]

        x = self.tile_embedding(map)  # [B, H*W, d_model]
        x = x + self.pos_embedding    # [B, H*W, d_model]

        summary = self.summary_tokens.expand(B, -1, -1)  # [B, L, d_model]
        x = torch.cat([summary, x], dim=1)               # [B, L+H*W, d_model]

        x = self.transformer(x)  # [B, L+H*W, d_model]

        z_e = self.proj(x[:, :self.L])  # [B, L, d_z]
        return z_e

    def forward(self, map: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Full forward pass: encode → quantize → compute losses.

        Args:
            map: [B, H*W] integer tile IDs (the full real map during training)

        Returns:
            z_q: [B, L, d_z] quantized z (with straight-through gradients), for MaskGIT
            indices: [B, L] codebook index at each position
            vq_loss: scalar, total VQ loss = beta * commit_loss + gamma * entropy_loss
        """
        z_e = self.encode(map)
        z_q, indices, commit_loss, entropy_loss = self.vq(z_e)

        vq_loss = self.beta * commit_loss + self.gamma * entropy_loss
        return z_q, indices, vq_loss

    def sample(self, B: int, device: torch.device) -> torch.Tensor:
        """
        Inference: sample L codes uniformly at random from the codebook.

        Args:
            B: batch size
            device: target device

        Returns:
            z: [B, L, d_z]
        """
        indices = torch.randint(0, self.K, (B, self.L), device=device)
        z = self.vq.codebook(indices)  # [B, L, d_z]
        return z


if __name__ == "__main__":
    device = torch.device("cpu")

    model = GinkaVQVAE(
        num_classes=16,
        L=2,
        K=16,
        d_z=64,
        d_model=128,
        nhead=4,
        num_layers=2,
        dim_ff=256,
        map_size=13 * 13,
    ).to(device)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"total parameters: {total_params:,} ({total_params / 1e6:.3f}M)")

    # per-module parameter counts
    for name, module in model.named_children():
        n = sum(p.numel() for p in module.parameters())
        print(f"  {name}: {n:,}")

    # forward-pass test
    map_input = torch.randint(0, 15, (4, 13 * 13)).to(device)  # [B=4, 169]

    z_q, indices, vq_loss = model(map_input)

    print(f"\nz_q shape: {z_q.shape}")        # [4, 2, 64]
    print(f"indices shape: {indices.shape}")  # [4, 2]
    print(f"vq_loss: {vq_loss.item():.4f}")

    # inference sampling test
    z_sample = model.sample(B=4, device=device)
    print(f"sample shape: {z_sample.shape}")  # [4, 2, 64]
ginka/vqvae/quantize.py (new file, 81 lines)
@@ -0,0 +1,81 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple


class VectorQuantizer(nn.Module):
    """
    Vector quantization layer.

    Maps a sequence of continuous encoding vectors to discrete codebook
    indices, keeping the gradient flow via the straight-through estimator.

    Uniform-usage regularization uses soft-assignment entropy maximization:
    a softmax over the distances gives soft assignment probabilities, from
    which the entropy of the mean code usage is computed; minimizing the
    negative entropy encourages all codes to be used equally.
    """

    def __init__(self, K: int, d_z: int, temp: float = 1.0):
        """
        Args:
            K: codebook size (number of codes)
            d_z: code embedding dimension
            temp: softmax temperature for the soft assignment; smaller = closer to hard assignment
        """
        super().__init__()
        self.K = K
        self.d_z = d_z
        self.temp = temp

        self.codebook = nn.Embedding(K, d_z)
        nn.init.uniform_(self.codebook.weight, -1.0 / K, 1.0 / K)

    def forward(self, z_e: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            z_e: [B, L, d_z] continuous vector sequence from the encoder

        Returns:
            z_q_st: [B, L, d_z] quantized vectors (straight-through gradients)
            indices: [B, L] codebook index at each position
            commit_loss: scalar commitment loss ||z_e - sg(z_q)||^2
            entropy_loss: scalar negative-entropy loss (minimizing it maximizes usage uniformity)
        """
        B, L, d_z = z_e.shape

        # flatten to [B*L, d_z]
        z_flat = z_e.reshape(B * L, d_z)

        codebook_w = self.codebook.weight  # [K, d_z]

        # L2 distances: ||z_e - e_k||^2 = ||z_e||^2 + ||e_k||^2 - 2 * z_e · e_k
        # distances: [B*L, K]
        distances = (
            (z_flat ** 2).sum(dim=1, keepdim=True)  # [B*L, 1]
            + (codebook_w ** 2).sum(dim=1)          # [K]
            - 2.0 * z_flat @ codebook_w.t()         # [B*L, K]
        )

        # hard assignment: index of the nearest code
        indices = distances.argmin(dim=1)  # [B*L]

        # quantized vectors
        z_q_flat = self.codebook(indices)  # [B*L, d_z]
        z_q = z_q_flat.reshape(B, L, d_z)

        # straight-through estimator: forward uses z_q, backward passes the gradient to z_e
        z_q_st = z_e + (z_q - z_e).detach()

        # commitment loss: pull encoder outputs toward their codes (updates the encoder only)
        commit_loss = F.mse_loss(z_e, z_q.detach())

        # entropy-maximization regularizer: mean code usage from the soft assignment, minimize the negative entropy
        # soft_assign: [B*L, K], softmax over distances (smaller distance → higher probability)
        soft_assign = F.softmax(-distances / self.temp, dim=1)
        avg_assign = soft_assign.mean(dim=0)  # [K], mean code usage
        # entropy_loss = -H(p) = sum(p * log(p)); minimizing it maximizes the entropy
        entropy_loss = (avg_assign * torch.log(avg_assign + 1e-10)).sum()

        indices = indices.reshape(B, L)
        return z_q_st, indices, commit_loss, entropy_loss