refactor: 采用 VQ + MaskGIT 方案

Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
unanmed 2026-04-26 23:45:56 +08:00
parent 1eda704986
commit 068940cae0
8 changed files with 1497 additions and 77 deletions

View File

@ -0,0 +1,397 @@
# Ginka 地图生成模型重构设计文档
## 背景与动机
### 现状与问题
当前模型采用**热力图Heatmap**作为生成条件,引导 MaskGIT 模型生成地图。实践中发现以下问题:
1. **模型侧**MaskGIT 无法有效利用热力图信息,热力图对生成结果的控制力弱;
2. **用户侧**:用户难以手动构造语义合理的热力图,交互成本高;
3. **表达侧**:热力图是高维连续信号,携带冗余信息,但实际指导生成的有效信息量有限。
### 改进目标
- 保留 MaskGIT 的地图生成范式,但将生成条件改为**低维离散隐变量 z**
- z 由 **VQ-VAE 风格的编码器**从真实地图中抽取,形成一个可复用的 **codebook**
- 训练期间通过 codebook 为 MaskGIT 提供多样性控制信号;
- 推理期间**直接随机采样 z**,无需用户提供任何额外输入。
---
## 整体架构
系统由两个子模型组成,训练可以分阶段进行,也可以联合训练。
```
┌─────────────────────────────────────────────────────────┐
│ 训练阶段 │
│ │
│ 真实地图 ──► [VQ-VAE 编码器] ──► z离散码序列
│ │ │
│ ▼ │
│ [Codebook] │
│ │ │
│ ▼ │
│ 掩码地图 + z ──► [MaskGIT] ──► 预测被遮盖的图块 │
└─────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────┐
│ 推理阶段 │
│ │
│ 随机采样 z ──► [MaskGIT 迭代解码] ──► 生成地图 │
└─────────────────────────────────────────────────────────┘
```
---
## 子模型一VQ-VAE 风格编码器
**规模约束**:编码器参数量应控制在 **< 1M**。此前 MaskGIT 在热力图方案中已验证 4M 模型处于欠拟合与过拟合的边界改用新方案整体规模不会有大幅变化VQ-VAE 编码器作为辅助模块无需过大
### 职责
将一张完整的地图(`[H, W]` 整数矩阵)编码为一个**离散的 z**z 是从 codebook 中查得的嵌入向量(或向量序列)。
### 设计方案
编码器统一采用**纯 Transformer 结构**地图尺寸13×13不大Transformer 不会带来过大计算开销,同时更擅长捕捉图块之间的长程依赖,优于 CNN 方案。
#### 方案 A全图单一码字z 为标量 index
- 编码器将整张地图压缩为一个 `d_z` 维向量,再量化为 codebook 中的单个码字;
- z 是一个整数 index对应 codebook 中一行嵌入;
- 结构最简单,但表达能力有限,难以捕捉地图的局部多样性。
```
地图 [H, W] ──► Transformer 编码 ──► [d_z] ──► 量化 ──► index标量
codebook[index] → z [d_z]
```
#### 方案 B空间码字序列z 为序列,**已选定**
- 编码器将地图编码为 `[L, d_z]` 的特征序列L 为码字数量);
- 每个位置独立量化为 codebook 中的一个码字;
- z 是 `L` 个 index 的组合,表达能力更强;
- 推理时随机采样 L 个 index 即可。
```
地图 [H, W] ──► Transformer 编码 ──► [L, d_z] ──► 逐位量化 ──► [L 个 index]
codebook[index_0..L] → z [L, d_z]
```
### VQ-VAE 量化机制
标准向量量化Vector Quantization
$$z_q = \arg\min_{e_k \in \mathcal{E}} \| z_e - e_k \|_2$$
其中 $\mathcal{E} = \{e_1, ..., e_K\}$ 为 codebook$K$ 为码本大小。
**直通估计Straight-Through Estimator**用于反向传播:
$$\frac{\partial \mathcal{L}}{\partial z_e} \approx \frac{\partial \mathcal{L}}{\partial z_q}$$
### 损失函数设计
总损失由三部分组成:
$$\mathcal{L}_{VQVAE} = \mathcal{L}_{recon} + \beta \cdot \mathcal{L}_{commit} + \gamma \cdot \mathcal{L}_{uniform}$$
| 损失项 | 说明 |
| ----------------------- | ----------------------------------------------------------- |
| $\mathcal{L}_{recon}$ | 重建损失,确保 z 能恢复地图信息(可选,视是否加解码器而定) |
| $\mathcal{L}_{commit}$ | 承诺损失 $\|z_e - \text{sg}(z_q)\|^2$,拉近编码向量与码字 |
| $\mathcal{L}_{uniform}$ | **均匀分布正则化**,鼓励各码字被均等使用(见下节) |
#### 均匀分布正则化
目标:让所有 `K` 个码字在训练中被均等使用,避免 codebook collapse。
方案一:**熵最大化**
$$\mathcal{L}_{uniform} = -H(p) = \sum_{k=1}^{K} p_k \log p_k$$
其中 $p_k$ 为码字 $k$ 在当前 batch 中被选中的频率(使用 EMA 估计)。
方案二:**EMA 更新 + 重置**
- 使用指数移动平均EMA更新 codebook不通过梯度更新
- 定期检测使用率低于阈值的码字,将其重置为当前 batch 中随机样本点。
方案三:**codebook 使用 KL 散度**
$$\mathcal{L}_{uniform} = \text{KL}(p \| \mathcal{U}_{K})$$
其中 $\mathcal{U}_K$ 为均匀分布,即期望每个码字被选中概率为 $1/K$。
**初始方案**:采用**熵最大化**(方案一),通过梯度直接优化,实现简单;若后续出现 codebook collapse 问题,再引入 EMA 更新 + 重置机制。
### 是否需要解码器
| | 有解码器 | 无解码器 |
| ---------- | -------------------------- | ----------------------------------------- |
| 重建损失 | 有z 需要还原地图 | 无z 的质量由 MaskGIT 的生成质量间接约束 |
| 训练复杂度 | 较高 | 较低 |
| z 语义性 | 强z 明确包含地图结构信息 | 弱z 更像风格/多样性控制 |
| 推荐场景 | 需要 z 携带结构语义 | 仅用于多样性控制 |
**已确认:不加解码器**,直接与 MaskGIT 联合训练。z 的语义由 MaskGIT 端的生成损失反向传播决定。z 定位为风格与多样性控制信号,而非结构重建指导——若 z 能完整重建地图,模型会忽略其他条件直接依赖 z反而降低可控性与泛化能力。
相应地,损失函数简化为:
$$\mathcal{L} = \mathcal{L}_{CE}(\text{MaskGIT}) + \beta \cdot \mathcal{L}_{commit} + \gamma \cdot \mathcal{L}_{uniform}$$
---
## 子模型二MaskGIT 改造
### 现有输入
```python
def forward(self, map: torch.Tensor, heatmap: torch.Tensor):
# map: [B, H * W]
# heatmap: [B, C, H, W]
```
### 改造后输入
```python
def forward(self, map: torch.Tensor, z: torch.Tensor):
# map: [B, H * W] — 掩码后的地图 token 序列
# z: [B, L, d_z] — 离散隐变量方案B或 [B, d_z]方案A
```
移除原有的 `GinkaMaskGITCond`(热力图 CNN 编码器)和 `gate_encoder`,用 z 替代其功能。
### z 注入 MaskGIT 的方式(待选)
#### 方式一Cross-Attention推荐
Transformer 解码器天然支持 cross-attention
```
map token sequence ──► self-attention ──► cross-attention ◄── z
预测 logits
```
z 作为 memory 输入到 `TransformerDecoder`map token 序列作为 query。
优点z 可以是任意长度的序列方案B友好表达力强。**已选定此方案。**
#### 方式二z 拼接到序列头部Prefix Token
将 z 的每个码字嵌入作为前缀 token拼接到 map token 序列前:
```
[z_0, z_1, ..., z_L, tile_0, tile_1, ..., tile_{H*W}]
```
输入到 TransformerEncoder 做自注意力,位置编码需相应扩展。
优点:实现简单,无需 encoder-decoder 分离。
#### 方式三Add/FiLM 调制
将 z 投影后,用 FiLMFeature-wise Linear Modulation调制每一层的特征
$$y = \gamma(z) \odot x + \beta(z)$$
优点参数高效缺点z 必须先聚合为单向量适合方案A
---
## 用户使用场景
调研结果显示,用户主要有以下两种使用意图:
| 场景 | 描述 | 对应训练子集 |
| ---------------- | ------------------------------------ | ------------------------- |
| **完全随机生成** | 用户不提供任何条件,直接生成完整地图 | 子集 C随机墙壁 + Mask |
| **墙壁辅助生成** | 用户手绘墙壁结构,模型填充非墙内容 | 子集 B / D |
训练策略需针对这两种场景显式建模,通过多子集混合训练使单一模型同时支持两种工作模式。
---
## 训练策略
### 多子集混合训练
每个 batch 从四个子集中按比例采样,各子集对模型能力的针对性不同:
#### 子集 A标准 MaskGIT 训练(比例 0.4 ~ 0.6
保留原有 MaskGIT 掩码范式,随机遮盖部分 tile训练模型基本的补全能力。
```
原始地图 ──► 随机 Mask 部分 tile ──► MaskGIT 预测被遮盖 tile
```
- 掩码策略:沿用现有 `MapMask`(随机掩码 + 分块掩码 + 形态学变换)
- 作用:维持模型的基础生成质量,防止针对性训练导致退化
#### 子集 B墙壁条件生成比例 0.1 ~ 0.3
去除地图中所有非墙 tile保留墙壁和空地让模型在给定墙壁结构下生成其余内容。
```
原始地图 ──► 清除所有非墙 tile保留 空地/墙壁)──► MaskGIT 生成非墙内容
```
- 输入仅含墙壁tile=1和空地tile=0的地图
- 目标:模型输出完整的地图(含怪物、道具、门、钥匙等)
- 对应场景:**用户手绘墙壁 → 模型填充**
#### 子集 C随机条件生成比例 0.1 ~ 0.3
在子集 B 的基础上,进一步对墙壁进行随机 Mask仅保留部分墙壁作为初始条件引入随机性以支持完全随机生成场景。
```
原始地图 ──► 清除非墙 tile ──► 随机 Mask 部分墙壁 tile ──► MaskGIT 生成
```
- 输入:随机保留的稀疏墙壁片段
- 目标:生成完整地图(含墙壁补全 + 非墙内容填充)
- 对应场景:**完全随机生成**(墙壁本身也需要生成)
- 注z 的随机采样在此子集中发挥关键作用,控制生成风格多样性
#### 子集 D入口条件生成比例 0.1 ~ 0.2
在子集 B / C 的基础上额外保留入口 tiletile=10训练模型在给定入口位置时调整生成结果。
```
原始地图 ──► 清除非墙非入口 tile保留 空地/墙壁/入口)──► MaskGIT 生成
```
- 输入:墙壁结构 + 入口位置
- 目标:生成与入口位置语义一致的完整地图
- 对应场景:**用户指定入口 → 模型生成与之匹配的布局**
- 注:单独设立此子集的原因是入口对地图整体拓扑结构有较强约束,需要专项强化
### 各子集比例汇总
| 子集 | 训练任务 | 建议比例 |
| ---- | -------------------------------- | --------- |
| A | 标准 MaskGIT随机掩码补全 | 0.4 ~ 0.6 |
| B | 墙壁条件生成(保留全部墙壁) | 0.1 ~ 0.3 |
| C | 随机条件生成(随机保留部分墙壁) | 0.1 ~ 0.3 |
| D | 入口条件生成(保留墙壁 + 入口) | 0.1 ~ 0.2 |
各子集比例之和为 1可作为 `dataset.py` 中采样权重使用。
### 联合训练流程
将 VQ-VAE 编码器与改造后的 MaskGIT 一起训练:
```
真实地图 ──► VQ-VAE 编码器 ──► z离散
按子集构造的掩码地图 + z ──► MaskGIT ──► 预测 logits ──► 交叉熵损失
VQ 损失commit + uniform
```
总损失:
$$\mathcal{L} = \mathcal{L}_{CE}(\text{MaskGIT}) + \beta \cdot \mathcal{L}_{commit} + \gamma \cdot \mathcal{L}_{uniform}$$
### Dropout z提升鲁棒性
训练时以一定概率(如 10-20%)将 z 替换为随机采样的码字,模拟推理时随机采样 z 的情况,避免模型过度依赖 z 的精确语义。在子集 C完全随机生成中此机制尤为重要。
---
## 推理流程
### 场景一:完全随机生成
```
1. 随机从 [0, K) 均匀采样 L 个 index
2. 从 codebook 查表,得到 z [L, d_z]
3. 初始化地图序列:以少量随机墙壁(比例 0.02 ~ 0.1)作为初始条件,其余位置填充 MASK token
4. 迭代 MaskGIT 解码cosine schedule
a. 输入掩码地图 + z → 预测所有 MASK 位置的 logits
b. 取置信度最高的 N 个 token 进行 unmask
c. 重复直到无 MASK token
5. 输出最终地图
```
**初始墙壁比例说明**:全 MASK 初始化会导致生成多样性不佳(即使使用 top-k 采样也难以改善。引入少量随机墙壁2%~10%)作为种子,可有效提升多样性;但比例不宜过高,否则随机放置的墙壁可能产生拓扑结构冲突,影响生成质量。
### 场景二:墙壁辅助生成(用户手绘墙壁)
```
1. 用户提供墙壁布局tile=1 的位置),其余位置填充 MASK token
2. 随机从 [0, K) 均匀采样 L 个 index得到 z
3. 迭代 MaskGIT 解码:
a. 输入带已知墙壁的掩码地图 + z → 预测 MASK 位置的 logits
b. 已知墙壁位置不参与 unmask保持不变
c. 取置信度最高的 N 个 MASK token 进行 unmask
d. 重复直到无 MASK token
4. 输出最终地图
```
### 场景三:入口指定生成
```
与场景二类似但初始条件同时固定墙壁位置和入口位置tile=10
其余 MASK 位置由 MaskGIT 迭代解码填充
```
---
## 超参数(待确定)
| 参数 | 说明 | 建议初始值 |
| -------------- | --------------------- | ---------- |
| `K` | codebook 大小 | 8 ~ 32 |
| `L` | 码字序列长度方案B | 1 ~ 4 |
| `d_z` | 码字嵌入维度 | 64 ~ 128 |
| `β` | commit loss 权重 | 0.25 |
| `γ` | uniform loss 权重 | 0.1 |
| z dropout 概率 | 训练时随机 z 的比例 | 0.1 ~ 0.2 |
地图固定为 13×13z 定位为风格与多样性控制信号K 无需过大8~32 足够L 也应保持较小以避免过度约束生成。
---
## 已决定事项
| 决策点 | 结论 |
| -------------- | -------------------------------------------------------------------------- |
| VQ-VAE 解码器 | **不加**,直接与 MaskGIT 联合训练z 语义由生成损失反向传播决定 |
| 编码器架构方案 | **方案 B**(序列码字),全 Transformer 结构;若训练不稳定可退回方案 A |
| z 注入方式 | **Cross-Attention**(方式一) |
| 均匀正则化策略 | **熵最大化**;若出现 collapse 再引入 EMA |
| 分层 VQ | **单层**13×13 地图无需多级量化 |
| 标量 cond 保留 | **暂不使用**;训练集标量分布严格,推理阶段难以生成合理值,后续可再考虑加入 |
## 待探索事项
- 合适的 K、L 取值(建议从 K=16, L=2 开始实验)
- z dropout 的最优概率
- 若后续 codebook collapse引入 EMA 更新 + 重置机制
- 若后续需要更细粒度控制:加入标量 cond需对推理侧标量做随机扰动处理
---
## 下一步行动
- [x] 确定编码方案(方案 B序列码字
- [x] 确定 z 注入方式Cross-Attention
- [x] 确定是否需要 VQ-VAE 解码器(不加)
- [x] 确定是否保留标量 cond暂不使用
- [x] 确定训练子集划分与各场景比例
- [x] 实现 VQ-VAE 编码器模块Transformer + VQ
- [ ] 改造 `GinkaMaskGIT.forward()` 接受 z替换热力图分支
- [ ] 实现四种子集采样逻辑(`dataset.py` 新增多子集 Dataset 或采样权重)
- [ ] 实现子集 B/C/D 的输入构造函数(按规则清除 tile、保留墙壁/入口)
- [ ] 编写联合训练脚本,整合 VQ 损失与 MaskGIT 交叉熵损失

View File

@ -229,4 +229,204 @@ class GinkaJointDataset(Dataset):
"target_map": target_map,
"target_heatmap": target_heatmap,
"cond_heatmap": cond_heatmap
}
}
class GinkaVQDataset(Dataset):
"""
用于 VQ-VAE + MaskGIT 联合训练的多子集数据集
每次 __getitem__ 按权重随机选取以下四种子集之一
A (standard): 标准 MaskGIT 随机掩码随机遮盖部分 tile
B (wall-only): 仅保留 wall(1) + floor(0)其余全部替换为 MASK(15)
C (wall-random): B 基础上再随机 mask 部分 wall tile
D (wall+entry): 仅保留 wall(1) + floor(0) + entrance(10)其余全部替换为 MASK(15)
返回 dict:
raw_map: LongTensor [H*W] 完整原始地图 VQ-VAE 编码
masked_map: LongTensor [H*W] MaskGIT 输入 mask 的位置 = 15
target_map: LongTensor [H*W] CE loss ground truth等同 raw_map
subset: str 子集标识供调试/统计用
"""
FLOOR = 0
WALL = 1
ENTRANCE = 10
MASK_ID = 15
def __init__(
self,
data_path: str,
subset_weights: tuple = (0.5, 0.2, 0.2, 0.1),
wall_mask_ratio: float = 0.3,
):
"""
Args:
data_path: JSON 数据文件路径
subset_weights: 子集 (A, B, C, D) 的采样权重自动归一化
wall_mask_ratio: Subset C 中额外随机 mask wall tile 比例上限
每次从 [0, wall_mask_ratio] 均匀采样实际比例
"""
self.data = load_data(data_path)
self.wall_mask_ratio = wall_mask_ratio
# 累积权重,用于快速随机子集选择
total_w = sum(subset_weights)
normalized = [x / total_w for x in subset_weights]
self.subset_cumw = [sum(normalized[:i + 1]) for i in range(len(normalized))]
def __len__(self):
return len(self.data)
# ------------------------------------------------------------------
# 内联随机掩码生成(避免 scipy 的 NumPy 版本兼容问题)
# ------------------------------------------------------------------
@staticmethod
def _sample_mask_ratio(min_r=0.05, max_r=1.0) -> float:
"""用 Beta(2,2) 分布采样掩码比例,集中在中间值。"""
r = np.random.beta(2, 2)
return min_r + (max_r - min_r) * r
@staticmethod
def _random_mask(h: int, w: int) -> np.ndarray:
"""纯随机掩码,返回 [H*W] bool。"""
ratio = GinkaVQDataset._sample_mask_ratio()
total = h * w
idx = np.random.choice(total, int(total * ratio), replace=False)
mask = np.zeros(total, dtype=bool)
mask[idx] = True
return mask
@staticmethod
def _block_mask(h: int, w: int) -> np.ndarray:
"""矩形分块随机掩码,返回 [H*W] bool。"""
ratio = GinkaVQDataset._sample_mask_ratio()
max_block = max(2, min(h, w) // 2)
target = int(h * w * ratio)
mask = np.zeros((h, w), dtype=bool)
while mask.sum() < target:
bh = np.random.randint(2, max_block + 1)
bw = np.random.randint(2, max_block + 1)
x = np.random.randint(0, max(1, h - bh + 1))
y = np.random.randint(0, max(1, w - bw + 1))
mask[x:x + bh, y:y + bw] = True
return mask.reshape(-1)
def _std_mask(self, h: int, w: int) -> np.ndarray:
"""标准 MaskGIT 掩码:随机选择纯随机或分块策略。"""
if random.random() < 0.5:
return self._random_mask(h, w)
else:
return self._block_mask(h, w)
# ------------------------------------------------------------------
def _augment(self, arr: np.ndarray) -> np.ndarray:
"""随机旋转 / 翻转数据增强,返回新 array。"""
if np.random.rand() > 0.5:
k = np.random.randint(1, 4)
arr = np.rot90(arr, k).copy()
if np.random.rand() > 0.5:
arr = np.fliplr(arr).copy()
if np.random.rand() > 0.5:
arr = np.flipud(arr).copy()
return arr
def _choose_subset(self) -> str:
r = random.random()
if r < self.subset_cumw[0]:
return 'A'
elif r < self.subset_cumw[1]:
return 'B'
elif r < self.subset_cumw[2]:
return 'C'
else:
return 'D'
def _apply_subset(self, raw: np.ndarray, subset: str) -> np.ndarray:
"""
根据子集类型生成 masked_map
Args:
raw: [H, W] int64 完整原始地图
subset: 'A' | 'B' | 'C' | 'D'
Returns:
[H*W] int64被遮盖位置值为 MASK_ID(15)
"""
H, W = raw.shape
if subset == 'A':
# 标准随机 mask纯随机或分块策略
mask = self._std_mask(H, W) # [H*W] bool
flat = raw.reshape(-1).copy()
flat[mask] = self.MASK_ID
return flat
elif subset == 'B':
# 仅保留 floor(0) 和 wall(1)
flat = raw.reshape(-1).copy()
keep = (flat == self.FLOOR) | (flat == self.WALL)
flat[~keep] = self.MASK_ID
return flat
elif subset == 'C':
# Subset B + 随机 mask 部分 wall
flat = raw.reshape(-1).copy()
keep = (flat == self.FLOOR) | (flat == self.WALL)
flat[~keep] = self.MASK_ID
wall_idx = np.where(flat == self.WALL)[0]
if len(wall_idx) > 0:
ratio = random.random() * self.wall_mask_ratio
n = max(1, int(len(wall_idx) * ratio))
chosen = np.random.choice(wall_idx, n, replace=False)
flat[chosen] = self.MASK_ID
return flat
else: # D
# 仅保留 floor(0)、wall(1) 和 entrance(10)
flat = raw.reshape(-1).copy()
keep = (
(flat == self.FLOOR)
| (flat == self.WALL)
| (flat == self.ENTRANCE)
)
flat[~keep] = self.MASK_ID
return flat
def __getitem__(self, idx):
item = self.data[idx]
raw_np = self._augment(np.array(item['map'], dtype=np.int64)) # [H, W]
subset = self._choose_subset()
masked_np = self._apply_subset(raw_np, subset) # [H*W]
raw_flat = raw_np.reshape(-1) # [H*W]
return {
"raw_map": torch.LongTensor(raw_flat), # VQ-VAE 编码器输入
"masked_map": torch.LongTensor(masked_np), # MaskGIT 输入
"target_map": torch.LongTensor(raw_flat.copy()), # CE loss ground truth
"subset": subset, # 调试/统计用
}
if __name__ == "__main__":
import os
data_path = os.path.join(os.path.dirname(__file__), '..', 'data', 'ginka-dataset.json')
ds = GinkaVQDataset(data_path)
print(f"数据集大小: {len(ds)}")
subset_count = {'A': 0, 'B': 0, 'C': 0, 'D': 0}
for i in range(200):
sample = ds[i % len(ds)]
subset_count[sample['subset']] += 1
raw = sample['raw_map']
masked = sample['masked_map']
target = sample['target_map']
print(f"raw_map shape={raw.shape}, dtype={raw.dtype}")
print(f"masked_map shape={masked.shape}, dtype={masked.dtype}")
print(f"target_map shape={target.shape}, dtype={target.dtype}")
print(f"被 mask 的位置数: {(masked == 15).sum().item()} / {masked.numel()}")
print(f"\n200 次采样子集分布: {subset_count}")

View File

@ -15,9 +15,15 @@ class Transformer(nn.Module):
num_layers=num_layers
)
def forward(self, x):
# x: [B, L, d_model]
m = self.encoder(x)
out = self.decoder(x, m)
def forward(self, x, memory=None):
# x: [B, S, d_model] 地图 token 序列
# memory: [B, L, d_model] 可选的 z 投影,用于 cross-attention
# 若 memory 为 None则退化为原始自编解码行为向后兼容
enc_out = self.encoder(x)
if memory is not None:
# encoder 输出作为 queryz 作为 key/value
out = self.decoder(enc_out, memory)
else:
out = self.decoder(x, enc_out)
return out

View File

@ -1,88 +1,117 @@
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..utils import print_memory
from .cond import GinkaMaskGITCond
from .maskGIT import Transformer
class GinkaMaskGIT(nn.Module):
"""
改造后的 MaskGIT 地图生成模型
以掩码地图序列和 VQ-VAE 输出的离散隐变量 z 为输入
通过 Transformer encoder-decoder 结构预测被遮盖位置的 tile 类别
z 通过 cross-attention 注入到 Transformer decoder
作为风格/多样性控制信号而非结构重建指导
"""
def __init__(
self, num_classes=16, heatmap_channel=4, d_model=256,
dim_ff=512, nhead=8, num_layers=4, map_size=13*13
self,
num_classes: int = 16,
d_model: int = 192,
d_z: int = 64,
dim_ff: int = 512,
nhead: int = 8,
num_layers: int = 4,
map_size: int = 13 * 13,
z_dropout: float = 0.1,
):
"""
Args:
num_classes: tile 类别数 MASK token=15
d_model: Transformer 内部维度
d_z: VQ-VAE 码字嵌入维度需与 GinkaVQVAE.d_z 一致
dim_ff: 前馈网络隐层维度
nhead: 注意力头数
num_layers: Transformer 层数
map_size: 地图 token 总数H * W
z_dropout: 训练时随机替换 z 为随机码字的概率提升鲁棒性
"""
super().__init__()
self.z_dropout = z_dropout
# Tile 嵌入 + 位置编码
self.tile_embedding = nn.Embedding(num_classes, d_model)
self.pos_embedding = nn.Parameter(torch.randn(1, map_size, d_model))
cond_channels = [d_model // 4, d_model // 2, d_model]
self.cond_encoder = GinkaMaskGITCond(input_channel=heatmap_channel, channels=cond_channels)
self.gate_encoder = nn.Sequential(
nn.Conv2d(cond_channels[2], cond_channels[2], 3, padding=1, padding_mode="replicate"),
nn.BatchNorm2d(cond_channels[2]),
nn.GELU()
self.pos_embedding = nn.Parameter(torch.randn(1, map_size, d_model) * 0.02)
# z 投影:将 VQ 码字从 d_z 维映射到 d_model 维,供 cross-attention 使用
self.z_proj = nn.Sequential(
nn.Linear(d_z, d_model),
nn.LayerNorm(d_model),
)
self.cond_gate = nn.Sequential(
nn.Linear(cond_channels[2] * 2, cond_channels[2]),
nn.LayerNorm(cond_channels[2]),
nn.Dropout(0.3),
nn.GELU(),
nn.Linear(cond_channels[2], cond_channels[2])
# Transformerencoder 做 map token 自注意力decoder 做与 z 的 cross-attention
self.transformer = Transformer(
d_model=d_model, dim_ff=dim_ff, nhead=nhead, num_layers=num_layers
)
self.transformer = Transformer(d_model=d_model, dim_ff=dim_ff, nhead=nhead, num_layers=num_layers)
self.output_fc = nn.Sequential(
nn.Linear(d_model, num_classes)
)
def forward(self, map: torch.Tensor, heatmap: torch.Tensor):
# map: [B, H * W]
# heatmap: [B, C, H, W]
# output: [B, H * W, num_classes]
heatmap = self.cond_encoder(heatmap) # [B, d_model, H, W]
B, C, H, W = heatmap.shape
heatmap_gate = self.gate_encoder(heatmap)
heatmap_mean = F.avg_pool2d(heatmap_gate, (H, W)) # [B, d_model, 1, 1]
heatmap_max = F.max_pool2d(heatmap_gate, (H, W)) # [B, d_model, 1, 1]
gate_input = torch.cat([heatmap_mean, heatmap_max], dim=1).squeeze(2).squeeze(2)
gate = self.cond_gate(gate_input) # [B, d_model]
heatmap = heatmap * torch.sigmoid(gate).unsqueeze(2).unsqueeze(2)
heatmap = heatmap.view(B, C, H * W).permute(0, 2, 1)
x = self.tile_embedding(map) + heatmap
x = x + self.pos_embedding
x = self.transformer(x)
logits = self.output_fc(x)
self.output_fc = nn.Linear(d_model, num_classes)
def forward(self, map: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
"""
Args:
map: [B, H*W] 掩码后的地图 token 序列MASK token = 15
z: [B, L, d_z] VQ-VAE 量化后的离散隐变量
Returns:
logits: [B, H*W, num_classes]
"""
# z dropout训练时以一定概率将 z 替换为随机均匀噪声,
# 模拟推理时随机采样 z 的分布,避免模型过拟合于精确的 z 语义
if self.training and self.z_dropout > 0:
mask = torch.rand(z.shape[0], 1, 1, device=z.device) < self.z_dropout
rand_z = torch.randn_like(z)
z = torch.where(mask, rand_z, z)
# 投影 z 到 d_model 维度
z_mem = self.z_proj(z) # [B, L, d_model]
# tile embedding + 位置编码
x = self.tile_embedding(map) # [B, H*W, d_model]
x = x + self.pos_embedding # [B, H*W, d_model]
# Transformerencoder 做 map 自注意力decoder cross-attend z
x = self.transformer(x, memory=z_mem) # [B, H*W, d_model]
logits = self.output_fc(x) # [B, H*W, num_classes]
return logits
if __name__ == "__main__":
device = torch.device("cpu")
map = torch.randint(0, 16, [1, 169]).to(device)
heatmap = torch.rand(1, 4, 13, 13).to(device)
# 初始化模型
model = GinkaMaskGIT().to(device)
print_memory("初始化后")
# 前向传播
start = time.perf_counter()
output = model(map, heatmap)
end = time.perf_counter()
print_memory("前向传播后")
print(f"推理耗时: {end - start}")
print(f"输出形状: output={output.shape}")
print(f"Tile Embedding parameters: {sum(p.numel() for p in model.tile_embedding.parameters())}")
print(f"Condition Encoder parameters: {sum(p.numel() for p in model.cond_encoder.parameters())}")
print(f"Condition Gate parameters: {sum(p.numel() for p in model.cond_gate.parameters())}")
print(f"MaskGIT parameters: {sum(p.numel() for p in model.transformer.parameters())}")
print(f"Output parameters: {sum(p.numel() for p in model.output_fc.parameters())}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")
model = GinkaMaskGIT(
num_classes=16,
d_model=192,
d_z=64,
dim_ff=512,
nhead=8,
num_layers=4,
map_size=13 * 13,
).to(device)
total_params = sum(p.numel() for p in model.parameters())
print(f"总参数量: {total_params:,} ({total_params / 1e6:.3f}M)")
for name, module in model.named_children():
n = sum(p.numel() for p in module.parameters())
print(f" {name}: {n:,}")
map_input = torch.randint(0, 16, (4, 13 * 13)).to(device) # [B=4, 169]
z_input = torch.randn(4, 2, 64).to(device) # [B=4, L=2, d_z=64]
model.train()
logits = model(map_input, z_input)
print(f"\nlogits shape: {logits.shape}") # [4, 169, 16]
print_memory(device, "前向传播后")

520
ginka/train_vq.py Normal file
View File

@ -0,0 +1,520 @@
"""
联合训练脚本VQ-VAE + MaskGIT
总损失 = L_CEMaskGIT 重建损失+ beta * L_commit + gamma * L_entropy
验证阶段对四种子集A/B/C/D分别输出图片
每条样本额外采样 N_Z_SAMPLES 个随机 z
便于直观对比同条件不同 z 下的生成差异
用法示例
python -m ginka.train_vq
python -m ginka.train_vq --resume True --state result/joint/joint-10.pth
"""
import argparse
import math
import os
import sys
import random
from datetime import datetime
import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
from torch.utils.data import DataLoader
from .vqvae.model import GinkaVQVAE
from .maskGIT.model import GinkaMaskGIT
from .dataset import GinkaVQDataset
from shared.image import matrix_to_image_cv
# ---------------------------------------------------------------------------
# 超参数
# ---------------------------------------------------------------------------
BATCH_SIZE = 64
NUM_CLASSES = 16
MASK_TOKEN = 15
GENERATE_STEP = 12 # 推理时 MaskGIT 迭代步数
MAP_SIZE = 13 * 13
MAP_H = MAP_W = 13
LABEL_SMOOTHING = 0.0
# VQ-VAE 超参
VQ_L = 2 # summary token 数量(即 z 的序列长度)
VQ_K = 16 # codebook 大小
VQ_D_Z = 64 # codebook 嵌入维度
VQ_D_MODEL= 128
VQ_NHEAD = 4
VQ_LAYERS = 2
VQ_DIM_FF = 256
VQ_BETA = 0.25 # commit loss 权重
VQ_GAMMA = 0.1 # entropy loss 权重
# MaskGIT 超参
MG_D_MODEL = 192
MG_NHEAD = 8
MG_LAYERS = 4
MG_DIM_FF = 512
MG_Z_DROPOUT= 0.15 # 训练时以此概率把 z 替换为随机噪声
# 验证时对每条样本额外采样的 z 数量0 = 只用真实 z
N_Z_SAMPLES = 3
# 四个子集 A/B/C/D 的采样权重(训练集与验证集共用)
SUBSET_WEIGHTS = (0.5, 0.2, 0.2, 0.1)
# ---------------------------------------------------------------------------
# 设备
# ---------------------------------------------------------------------------
device = torch.device(
"cuda:1" if torch.cuda.is_available()
else "mps" if torch.backends.mps.is_available()
else "cpu"
)
os.makedirs("result/joint", exist_ok=True)
os.makedirs("result/joint_img", exist_ok=True)
disable_tqdm = not sys.stdout.isatty()
# ---------------------------------------------------------------------------
# 参数解析
# ---------------------------------------------------------------------------
def parse_arguments():
parser = argparse.ArgumentParser(description="VQ-VAE + MaskGIT 联合训练")
parser.add_argument("--resume", type=bool, default=False)
parser.add_argument("--state", type=str, default="result/joint/joint-10.pth",
help="续训时加载的检查点路径")
parser.add_argument("--train", type=str, default="data/ginka-dataset.json")
parser.add_argument("--validate", type=str, default="data/ginka-eval.json")
parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--checkpoint", type=int, default=5,
help="每隔多少 epoch 保存检查点并验证")
parser.add_argument("--load_optim", type=bool, default=True)
return parser.parse_args()
# ---------------------------------------------------------------------------
# MaskGIT 推理cosine schedule 迭代解码)
# ---------------------------------------------------------------------------
@torch.no_grad()
def maskgit_generate(
model_mg: GinkaMaskGIT,
z: torch.Tensor,
steps: int = GENERATE_STEP,
init_map: torch.Tensor = None,
) -> torch.Tensor:
"""
迭代生成地图cosine schedule unmasking
Args:
model_mg: GinkaMaskGITeval 模式
z: [B, L, d_z] 条件 z
steps: 解码步数
init_map: [B, MAP_SIZE] 可选初始地图 MASK 位置在生成过程中保持固定
None 时从全 MASK 开始自由生成
Returns:
map_out: [B, MAP_SIZE]
"""
B = z.shape[0]
if init_map is None:
map_seq = torch.full((B, MAP_SIZE), MASK_TOKEN, device=device)
else:
map_seq = init_map.clone().to(device)
# 记录初始 MASK 位置,这些位置才需要生成
generatable = (map_seq == MASK_TOKEN) # [B, S] bool
for step in range(steps):
if not generatable.any():
break
logits = model_mg(map_seq, z) # [B, S, C]
probs = F.softmax(logits, dim=-1)
dist = torch.distributions.Categorical(probs)
sampled = dist.sample() # [B, S]
# 计算置信度;固定位置设为 +inf确保不会被选为“继续保持 MASK”
confidences = torch.gather(probs, -1, sampled.unsqueeze(-1)).squeeze(-1)
confidences = confidences.masked_fill(~generatable, float('inf'))
ratio = math.cos(((step + 1) / steps) * math.pi / 2)
new_map = map_seq.clone()
for b in range(B):
n_gen = int(generatable[b].sum().item())
n_keep = int(ratio * n_gen) # 本步仍保持 MASK 的位置数
if n_keep > 0:
_, keep_idx = torch.topk(confidences[b], k=n_keep, largest=False)
pred_b = sampled[b].clone()
pred_b[keep_idx] = MASK_TOKEN
new_map[b] = torch.where(generatable[b], pred_b, map_seq[b])
else:
new_map[b] = torch.where(generatable[b], sampled[b], map_seq[b])
map_seq = new_map
return map_seq
# ---------------------------------------------------------------------------
# 可视化工具
# ---------------------------------------------------------------------------
def make_map_image(map_flat: torch.Tensor, tile_dict: dict) -> np.ndarray:
"""将 [MAP_SIZE] 的 tensor 转成 RGB 图片numpy"""
arr = map_flat.cpu().numpy().reshape(MAP_H, MAP_W)
return matrix_to_image_cv(arr, tile_dict)
def hstack_images(imgs: list, gap: int = 4, color=(255, 255, 255)) -> np.ndarray:
"""将多张等高图片横向拼接,之间插入白色竖线。"""
H = imgs[0].shape[0]
vline = np.full((H, gap, 3), color, dtype=np.uint8)
result = imgs[0]
for img in imgs[1:]:
result = np.concatenate([result, vline, img], axis=1)
return result
def label_image(img: np.ndarray, text: str, font_scale: float = 0.45) -> np.ndarray:
"""在图片顶部加一行文字标签(就地修改并返回)。"""
bar_h = 16
bar = np.full((bar_h, img.shape[1], 3), (40, 40, 40), dtype=np.uint8)
cv2.putText(
bar, text, (2, bar_h - 3),
cv2.FONT_HERSHEY_SIMPLEX, font_scale,
(200, 200, 200), 1, cv2.LINE_AA
)
return np.concatenate([bar, img], axis=0)
def make_random_wall_seed(ratio_min: float = 0.02, ratio_max: float = 0.08) -> torch.Tensor:
"""
在全 MASK 地图上随机放置少量墙壁作为推理种子用于完全随机生成场景
Returns:
[1, MAP_SIZE] MASK=15 背景 + 随机置少量墙壁tile=1
"""
ratio = random.uniform(ratio_min, ratio_max)
n_wall = max(2, int(MAP_SIZE * ratio))
seed = torch.full((1, MAP_SIZE), MASK_TOKEN, dtype=torch.long, device=device)
idx = torch.randperm(MAP_SIZE)[:n_wall]
seed[0, idx] = 1 # wall
return seed
@torch.no_grad()
def validate(
model_vq: GinkaVQVAE,
model_mg: GinkaMaskGIT,
dataloader_val: DataLoader,
tile_dict: dict,
epoch: int,
):
"""
验证函数计算 val loss 并输出 5 类推理场景的对比图
场景说明 epoch 建立子文件夹避免图片堆积
场景1 (scene1_completion) : 子集 A标准随机掩码补全
: ground truth | masked input | z_real pred | z_real gen | z_rand×N
场景2 (scene2_wall) : 子集 B仅墙壁+空地 生成完整地图
: ground truth | wall-only input | z_real gen | z_rand×N
场景3 (scene3_sparse) : 子集 C稀疏墙壁条件 生成完整地图
: ground truth | sparse wall input | z_real gen | z_rand×N
场景4 (scene4_entrance) : 子集 D墙壁+入口 生成完整地图
: ground truth | wall+entrance input | z_real gen | z_rand×N
场景5 (scene5_random) : 无数据集参照随机稀疏墙壁种子 完全随机生成
: random seed | z_rand×(N+1)
"""
model_vq.eval()
model_mg.eval()
# 按 epoch 建立独立子文件夹,保留每次验证结果方便回溯
epoch_dir = f"result/joint_img/e{epoch:04d}"
os.makedirs(epoch_dir, exist_ok=True)
val_loss_total = 0.0
val_steps = 0
captured = {s: None for s in ('A', 'B', 'C', 'D')}
# ── 计算 val loss + 捕获各子集样本 ──────────────────────────────────────
for batch in tqdm(dataloader_val, desc="Validating", leave=False, disable=disable_tqdm):
raw_map = batch["raw_map"].to(device) # [B, 169]
masked_map = batch["masked_map"].to(device) # [B, 169]
target_map = batch["target_map"].to(device) # [B, 169]
subsets = batch["subset"] # list of str
B = raw_map.shape[0]
z_q, _, vq_loss = model_vq(raw_map)
logits = model_mg(masked_map, z_q)
mask = (masked_map == MASK_TOKEN)
ce_loss = F.cross_entropy(
logits.permute(0, 2, 1), target_map,
reduction='none', label_smoothing=LABEL_SMOOTHING
)
masked_ce = (ce_loss * mask).sum() / (mask.sum() + 1e-6)
val_loss_total += (masked_ce + vq_loss).item()
val_steps += 1
for i in range(B):
s = subsets[i]
if captured[s] is None:
captured[s] = {
"raw": raw_map[i:i+1].clone(),
"masked": masked_map[i:i+1].clone(),
"z_q": z_q[i:i+1].clone(),
}
if all(v is not None for v in captured.values()):
break
# ── 公共辅助:对给定条件地图随机采样 n 次 z 并迭代生成 ──────────────────
def _rand_gens(cond_map, n):
imgs = []
for i in range(n):
z_r = model_vq.sample(1, device)
gen = maskgit_generate(model_mg, z_r, init_map=cond_map)
imgs.append(label_image(make_map_image(gen[0], tile_dict), f"z_rand_{i + 1}"))
return imgs
# ── 场景1标准掩码补全子集 A─────────────────────────────────────────
if captured['A'] is not None:
cap = captured['A']
raw, cond, z_q = cap['raw'], cap['masked'], cap['z_q']
real_img = label_image(make_map_image(raw[0], tile_dict), "ground truth")
cond_img = label_image(make_map_image(cond[0], tile_dict), "masked input")
# 单步 argmax 预测(观察模型对掩码位置的瞬时判断)
pred = model_mg(cond, z_q).argmax(dim=-1)[0]
pred_img = label_image(make_map_image(pred, tile_dict), "z_real pred")
# 迭代生成(从掩码输入出发,真实 z
gen_real = maskgit_generate(model_mg, z_q, init_map=cond)
gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")
row = [real_img, cond_img, pred_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
cv2.imwrite(f"{epoch_dir}/scene1_completion.png", hstack_images(row))
# ── 场景2墙壁辅助生成子集 B─────────────────────────────────────────
if captured['B'] is not None:
cap = captured['B']
raw, cond, z_q = cap['raw'], cap['masked'], cap['z_q']
real_img = label_image(make_map_image(raw[0], tile_dict), "ground truth")
cond_img = label_image(make_map_image(cond[0], tile_dict), "wall-only input")
gen_real = maskgit_generate(model_mg, z_q, init_map=cond)
gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")
row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
cv2.imwrite(f"{epoch_dir}/scene2_wall.png", hstack_images(row))
# ── 场景3稀疏墙壁条件生成子集 C────────────────────────────────────
if captured['C'] is not None:
cap = captured['C']
raw, cond, z_q = cap['raw'], cap['masked'], cap['z_q']
real_img = label_image(make_map_image(raw[0], tile_dict), "ground truth")
cond_img = label_image(make_map_image(cond[0], tile_dict), "sparse wall input")
gen_real = maskgit_generate(model_mg, z_q, init_map=cond)
gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")
row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
cv2.imwrite(f"{epoch_dir}/scene3_sparse.png", hstack_images(row))
# ── 场景4墙壁+入口条件生成(子集 D───────────────────────────────────
if captured['D'] is not None:
cap = captured['D']
raw, cond, z_q = cap['raw'], cap['masked'], cap['z_q']
real_img = label_image(make_map_image(raw[0], tile_dict), "ground truth")
cond_img = label_image(make_map_image(cond[0], tile_dict), "wall+entrance input")
gen_real = maskgit_generate(model_mg, z_q, init_map=cond)
gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")
row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
cv2.imwrite(f"{epoch_dir}/scene4_entrance.png", hstack_images(row))
# ── 场景5完全随机生成无数据集参照──────────────────────────────────
# 随机稀疏墙壁种子,对应推理时"不提供任何条件,直接生成"的场景
rand_seed = make_random_wall_seed() # [1, MAP_SIZE]
seed_img = label_image(make_map_image(rand_seed[0], tile_dict), "random seed")
row = [seed_img] + _rand_gens(rand_seed, N_Z_SAMPLES + 1) # 多采一个 z 展示多样性
cv2.imwrite(f"{epoch_dir}/scene5_random.png", hstack_images(row))
avg_val_loss = val_loss_total / max(val_steps, 1)
return avg_val_loss
# ---------------------------------------------------------------------------
# 主训练函数
# ---------------------------------------------------------------------------
def train():
print(f"Using device: {device}")
args = parse_arguments()
# ---- 模型 ----
model_vq = GinkaVQVAE(
num_classes=NUM_CLASSES,
L=VQ_L, K=VQ_K, d_z=VQ_D_Z,
d_model=VQ_D_MODEL, nhead=VQ_NHEAD,
num_layers=VQ_LAYERS, dim_ff=VQ_DIM_FF,
map_size=MAP_SIZE,
beta=VQ_BETA, gamma=VQ_GAMMA,
).to(device)
model_mg = GinkaMaskGIT(
num_classes=NUM_CLASSES,
d_model=MG_D_MODEL, d_z=VQ_D_Z,
dim_ff=MG_DIM_FF, nhead=MG_NHEAD,
num_layers=MG_LAYERS,
map_size=MAP_SIZE,
z_dropout=MG_Z_DROPOUT,
).to(device)
vq_params = sum(p.numel() for p in model_vq.parameters())
mg_params = sum(p.numel() for p in model_mg.parameters())
print(f"VQ-VAE 参数量: {vq_params:,} ({vq_params/1e6:.3f}M)")
print(f"MaskGIT 参数量: {mg_params:,} ({mg_params/1e6:.3f}M)")
print(f"Total 参数量: {vq_params+mg_params:,} ({(vq_params+mg_params)/1e6:.3f}M)")
# ---- 数据集 ----
dataset_train = GinkaVQDataset(
args.train,
subset_weights=SUBSET_WEIGHTS,
)
dataset_val = GinkaVQDataset(
args.validate,
subset_weights=SUBSET_WEIGHTS,
)
dataloader_train = DataLoader(
dataset_train, batch_size=BATCH_SIZE, shuffle=True,
num_workers=0, pin_memory=(device.type == "cuda"),
)
dataloader_val = DataLoader(
dataset_val, batch_size=8, shuffle=True,
num_workers=0,
)
# ---- 优化器(联合训练,两个模型共用一个 optimizer----
all_params = list(model_vq.parameters()) + list(model_mg.parameters())
optimizer = optim.AdamW(all_params, lr=2e-4, weight_decay=1e-2)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=args.epochs, eta_min=1e-6
)
# ---- 续训 ----
start_epoch = 0
if args.resume:
ckpt = torch.load(args.state, map_location=device)
model_vq.load_state_dict(ckpt["vq_state"], strict=False)
model_mg.load_state_dict(ckpt["mg_state"], strict=False)
if args.load_optim and ckpt.get("optim_state") is not None:
optimizer.load_state_dict(ckpt["optim_state"])
start_epoch = ckpt.get("epoch", 0)
print(f"从 epoch {start_epoch} 接续训练。")
# ---- tile 贴图(用于验证可视化)----
tile_dict = {}
for file in os.listdir("tiles"):
name = os.path.splitext(file)[0]
img = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED)
if img is not None:
tile_dict[name] = img
# ---- 训练循环 ----
for epoch in tqdm(range(start_epoch, start_epoch + args.epochs),
desc="Joint Training", disable=disable_tqdm):
model_vq.train()
model_mg.train()
loss_total = 0.0
ce_total = 0.0
vq_loss_total = 0.0
subset_stats = {'A': 0, 'B': 0, 'C': 0, 'D': 0}
for batch in tqdm(dataloader_train, leave=False,
desc="Epoch Progress", disable=disable_tqdm):
raw_map = batch["raw_map"].to(device) # [B, 169]
masked_map = batch["masked_map"].to(device) # [B, 169]
target_map = batch["target_map"].to(device) # [B, 169]
for s in batch["subset"]:
subset_stats[s] = subset_stats.get(s, 0) + 1
# ---- 前向传播 ----
# 1. VQ-VAE 编码真实地图 → z_q
z_q, _, vq_loss = model_vq(raw_map) # z_q: [B, L, d_z]
# 2. MaskGIT 以掩码地图 + z 预测原始 tile
logits = model_mg(masked_map, z_q) # [B, 169, C]
# 3. 只对被 mask 的位置计算 CE loss
mask = (masked_map == MASK_TOKEN) # [B, 169] bool
ce_loss = F.cross_entropy(
logits.permute(0, 2, 1), target_map,
reduction='none', label_smoothing=LABEL_SMOOTHING
)
masked_ce = (ce_loss * mask).sum() / (mask.sum() + 1e-6)
# 4. 联合损失
loss = masked_ce + vq_loss
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(all_params, max_norm=1.0)
optimizer.step()
loss_total += loss.detach().item()
ce_total += masked_ce.detach().item()
vq_loss_total += vq_loss.detach().item()
scheduler.step()
n = len(dataloader_train)
tqdm.write(
f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
f"Epoch {epoch + 1:4d} | "
f"Loss {loss_total/n:.5f} "
f"CE {ce_total/n:.5f} "
f"VQ {vq_loss_total/n:.5f} | "
f"LR {scheduler.get_last_lr()[0]:.6f} | "
f"Subsets {subset_stats}"
)
# ---- 检查点 + 验证 ----
if (epoch + 1) % args.checkpoint == 0:
ckpt_path = f"result/joint/joint-{epoch + 1}.pth"
torch.save({
"epoch": epoch + 1,
"vq_state": model_vq.state_dict(),
"mg_state": model_mg.state_dict(),
"optim_state":optimizer.state_dict(),
}, ckpt_path)
tqdm.write(f" 检查点已保存: {ckpt_path}")
val_loss = validate(
model_vq, model_mg, dataloader_val, tile_dict, epoch + 1
)
tqdm.write(
f"[Validate] Epoch {epoch + 1:4d} | Val Loss {val_loss:.5f}"
)
# 恢复训练模式
model_vq.train()
model_mg.train()
print("训练结束。")
torch.save({
"epoch": start_epoch + args.epochs,
"vq_state": model_vq.state_dict(),
"mg_state": model_mg.state_dict(),
}, "result/joint/joint_final.pth")
# ---------------------------------------------------------------------------
if __name__ == "__main__":
torch.set_num_threads(4)
train()

1
ginka/vqvae/__init__.py Normal file
View File

@ -0,0 +1 @@
from .model import GinkaVQVAE

186
ginka/vqvae/model.py Normal file
View File

@ -0,0 +1,186 @@
import torch
import torch.nn as nn
from .quantize import VectorQuantizer
from typing import Tuple
class GinkaVQVAE(nn.Module):
"""
VQ-VAE 风格地图编码器
将一张完整的地图[B, H*W] 整数 tile ID 序列编码为 L 个离散码字
输出 z [B, L, d_z] 作为 MaskGIT 模型的生成条件
架构
tile embedding + 位置编码
L 个可学习 summary token拼接到序列头部
Transformer EncoderPre-LN自注意力
取前 L 个输出
线性投影到 d_z
VectorQuantizer直通估计 + 熵最大化正则
设计约束
- 参数量目标 < 1M
- 不含解码器z 的语义由 MaskGIT 端的交叉熵损失间接约束
- z 定位为风格/多样性控制信号而非结构重建指导
"""
def __init__(
self,
num_classes: int = 16,
L: int = 2,
K: int = 16,
d_z: int = 64,
d_model: int = 128,
nhead: int = 4,
num_layers: int = 2,
dim_ff: int = 256,
map_size: int = 13 * 13,
beta: float = 0.25,
gamma: float = 0.1,
vq_temp: float = 1.0,
):
"""
Args:
num_classes: tile 类别数 MASK token
L: 码字序列长度 z 的序列维度
K: codebook 大小码字总数
d_z: 码字嵌入维度
d_model: Transformer 内部维度
nhead: 注意力头数
num_layers: Transformer 层数
dim_ff: 前馈网络隐层维度
map_size: 地图 token 总数H * W
beta: 承诺损失权重
gamma: 熵正则损失权重
vq_temp: VQ 软分配 softmax 温度
"""
super().__init__()
self.L = L
self.K = K
self.d_z = d_z
self.beta = beta
self.gamma = gamma
# Tile 嵌入
self.tile_embedding = nn.Embedding(num_classes, d_model)
# 地图位置编码(仅覆盖 map_size 个位置,不含 summary token
self.pos_embedding = nn.Parameter(torch.randn(1, map_size, d_model) * 0.02)
# L 个可学习 summary token拼接到序列头部
self.summary_tokens = nn.Parameter(torch.randn(1, L, d_model) * 0.02)
# Pre-LN Transformer Encoder训练更稳定
encoder_layer = nn.TransformerEncoderLayer(
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_ff,
batch_first=True,
activation='gelu',
norm_first=True, # Pre-LN
dropout=0.0,
)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
# 将 Transformer 输出投影到 codebook 维度 d_z
self.proj = nn.Sequential(
nn.Linear(d_model, d_z),
nn.LayerNorm(d_z),
)
# 向量量化层
self.vq = VectorQuantizer(K=K, d_z=d_z, temp=vq_temp)
def encode(self, map: torch.Tensor) -> torch.Tensor:
"""
将地图编码为量化前的连续向量序列
Args:
map: [B, H*W] 整数 tile ID
Returns:
z_e: [B, L, d_z] 量化前的编码向量
"""
B = map.shape[0]
x = self.tile_embedding(map) # [B, H*W, d_model]
x = x + self.pos_embedding # [B, H*W, d_model]
summary = self.summary_tokens.expand(B, -1, -1) # [B, L, d_model]
x = torch.cat([summary, x], dim=1) # [B, L+H*W, d_model]
x = self.transformer(x) # [B, L+H*W, d_model]
z_e = self.proj(x[:, :self.L]) # [B, L, d_z]
return z_e
def forward(self, map: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
完整前向传播编码 量化 计算损失
Args:
map: [B, H*W] 整数 tile ID训练时传入完整真实地图
Returns:
z_q: [B, L, d_z] 量化后的 z含直通梯度 MaskGIT 使用
indices: [B, L] 每个位置对应的码字索引
vq_loss: scalar VQ 总损失 = beta * commit_loss + gamma * entropy_loss
"""
z_e = self.encode(map)
z_q, indices, commit_loss, entropy_loss = self.vq(z_e)
vq_loss = self.beta * commit_loss + self.gamma * entropy_loss
return z_q, indices, vq_loss
def sample(self, B: int, device: torch.device) -> torch.Tensor:
"""
推理阶段 codebook 中随机均匀采样 L 个码字
Args:
B: batch size
device: 目标设备
Returns:
z: [B, L, d_z]
"""
indices = torch.randint(0, self.K, (B, self.L), device=device)
z = self.vq.codebook(indices) # [B, L, d_z]
return z
if __name__ == "__main__":
device = torch.device("cpu")
model = GinkaVQVAE(
num_classes=16,
L=2,
K=16,
d_z=64,
d_model=128,
nhead=4,
num_layers=2,
dim_ff=256,
map_size=13 * 13,
).to(device)
total_params = sum(p.numel() for p in model.parameters())
print(f"总参数量: {total_params:,} ({total_params / 1e6:.3f}M)")
# 分模块参数统计
for name, module in model.named_children():
n = sum(p.numel() for p in module.parameters())
print(f" {name}: {n:,}")
# 前向传播测试
map_input = torch.randint(0, 15, (4, 13 * 13)).to(device) # [B=4, 169]
z_q, indices, vq_loss = model(map_input)
print(f"\nz_q shape: {z_q.shape}") # [4, 2, 64]
print(f"indices shape:{indices.shape}") # [4, 2]
print(f"vq_loss: {vq_loss.item():.4f}")
# 推理采样测试
z_sample = model.sample(B=4, device=device)
print(f"sample shape: {z_sample.shape}") # [4, 2, 64]

81
ginka/vqvae/quantize.py Normal file
View File

@ -0,0 +1,81 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple
class VectorQuantizer(nn.Module):
"""
向量量化层Vector Quantization
将连续的编码向量序列映射到离散的 codebook 码字索引
并通过直通估计Straight-Through Estimator保持梯度流
均匀分布正则化采用软分配熵最大化方案
通过对距离做 softmax 得到软分配概率计算平均码字使用率的熵
最小化负熵以鼓励所有码字被均等使用
"""
def __init__(self, K: int, d_z: int, temp: float = 1.0):
"""
Args:
K: codebook 大小码字数量
d_z: 码字嵌入维度
temp: 软分配 softmax 温度越小越接近 hard assignment
"""
super().__init__()
self.K = K
self.d_z = d_z
self.temp = temp
self.codebook = nn.Embedding(K, d_z)
nn.init.uniform_(self.codebook.weight, -1.0 / K, 1.0 / K)
def forward(self, z_e: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
z_e: [B, L, d_z] 编码器输出的连续向量序列
Returns:
z_q_st: [B, L, d_z] 量化后向量直通梯度
indices: [B, L] 每个位置对应的码字索引
commit_loss: scalar 承诺损失 ||z_e - sg(z_q)||^2
entropy_loss: scalar 负熵损失最小化 = 最大化码字使用均匀度
"""
B, L, d_z = z_e.shape
# 展平到 [B*L, d_z]
z_flat = z_e.reshape(B * L, d_z)
codebook_w = self.codebook.weight # [K, d_z]
# 计算 L2 距离:||z_e - e_k||^2 = ||z_e||^2 + ||e_k||^2 - 2 * z_e · e_k
# distances: [B*L, K]
distances = (
(z_flat ** 2).sum(dim=1, keepdim=True) # [B*L, 1]
+ (codebook_w ** 2).sum(dim=1) # [K]
- 2.0 * z_flat @ codebook_w.t() # [B*L, K]
)
# Hard assignment取最近码字索引
indices = distances.argmin(dim=1) # [B*L]
# 量化向量
z_q_flat = self.codebook(indices) # [B*L, d_z]
z_q = z_q_flat.reshape(B, L, d_z)
# 直通估计:前向传 z_q反向传 z_e 的梯度
z_q_st = z_e + (z_q - z_e).detach()
# 承诺损失:拉近编码向量与其对应的码字(仅更新编码器)
commit_loss = F.mse_loss(z_e, z_q.detach())
# 熵最大化正则:通过软分配计算平均码字使用率,最小化负熵
# soft_assign: [B*L, K],对距离做 softmax距离越小概率越大
soft_assign = F.softmax(-distances / self.temp, dim=1)
avg_assign = soft_assign.mean(dim=0) # [K],平均码字使用率
# entropy_loss = -H(p) = sum(p * log(p)),最小化即最大化熵
entropy_loss = (avg_assign * torch.log(avg_assign + 1e-10)).sum()
indices = indices.reshape(B, L)
return z_q_st, indices, commit_loss, entropy_loss