diff --git a/docs/vqvae-maskgit-design.md b/docs/vqvae-maskgit-design.md new file mode 100644 index 0000000..ede2202 --- /dev/null +++ b/docs/vqvae-maskgit-design.md @@ -0,0 +1,397 @@ +# Ginka 地图生成模型重构设计文档 + +## 背景与动机 + +### 现状与问题 + +当前模型采用**热力图(Heatmap)**作为生成条件,引导 MaskGIT 模型生成地图。实践中发现以下问题: + +1. **模型侧**:MaskGIT 无法有效利用热力图信息,热力图对生成结果的控制力弱; +2. **用户侧**:用户难以手动构造语义合理的热力图,交互成本高; +3. **表达侧**:热力图是高维连续信号,携带冗余信息,但实际指导生成的有效信息量有限。 + +### 改进目标 + +- 保留 MaskGIT 的地图生成范式,但将生成条件改为**低维离散隐变量 z**; +- z 由 **VQ-VAE 风格的编码器**从真实地图中抽取,形成一个可复用的 **codebook**; +- 训练期间通过 codebook 为 MaskGIT 提供多样性控制信号; +- 推理期间**直接随机采样 z**,无需用户提供任何额外输入。 + +--- + +## 整体架构 + +系统由两个子模型组成,训练可以分阶段进行,也可以联合训练。 + +``` +┌─────────────────────────────────────────────────────────┐ +│ 训练阶段 │ +│ │ +│ 真实地图 ──► [VQ-VAE 编码器] ──► z(离散码序列) │ +│ │ │ +│ ▼ │ +│ [Codebook] │ +│ │ │ +│ ▼ │ +│ 掩码地图 + z ──► [MaskGIT] ──► 预测被遮盖的图块 │ +└─────────────────────────────────────────────────────────┘ + +┌─────────────────────────────────────────────────────────┐ +│ 推理阶段 │ +│ │ +│ 随机采样 z ──► [MaskGIT 迭代解码] ──► 生成地图 │ +└─────────────────────────────────────────────────────────┘ +``` + +--- + +## 子模型一:VQ-VAE 风格编码器 + +**规模约束**:编码器参数量应控制在 **< 1M**。此前 MaskGIT 在热力图方案中已验证 4M 模型处于欠拟合与过拟合的边界,改用新方案整体规模不会有大幅变化,VQ-VAE 编码器作为辅助模块无需过大。 + +### 职责 + +将一张完整的地图(`[H, W]` 整数矩阵)编码为一个**离散的 z**,z 是从 codebook 中查得的嵌入向量(或向量序列)。 + +### 设计方案 + +编码器统一采用**纯 Transformer 结构**:地图尺寸(13×13)不大,Transformer 不会带来过大计算开销,同时更擅长捕捉图块之间的长程依赖,优于 CNN 方案。 + +#### 方案 A:全图单一码字(z 为标量 index) + +- 编码器将整张地图压缩为一个 `d_z` 维向量,再量化为 codebook 中的单个码字; +- z 是一个整数 index,对应 codebook 中一行嵌入; +- 结构最简单,但表达能力有限,难以捕捉地图的局部多样性。 + +``` +地图 [H, W] ──► Transformer 编码 ──► [d_z] ──► 量化 ──► index(标量) + │ + ▼ + codebook[index] → z [d_z] +``` + +#### 方案 B:空间码字序列(z 为序列,**已选定**) + +- 编码器将地图编码为 `[L, d_z]` 的特征序列(L 为码字数量); +- 每个位置独立量化为 codebook 中的一个码字; +- z 是 `L` 个 index 的组合,表达能力更强; +- 推理时随机采样 L 个 index 即可。 + +``` +地图 [H, W] ──► Transformer 编码 ──► [L, d_z] ──► 逐位量化 ──► [L 个 index] + │ + ▼ + codebook[index_0..L] → z [L, d_z] +``` + +### VQ-VAE 量化机制 + +标准向量量化(Vector Quantization): + +$$z_q = \arg\min_{e_k \in \mathcal{E}} \| z_e - e_k \|_2$$ + +其中 $\mathcal{E} = \{e_1, ..., e_K\}$ 为 codebook,$K$ 为码本大小。 + +**直通估计(Straight-Through Estimator)**用于反向传播: + +$$\frac{\partial \mathcal{L}}{\partial z_e} \approx \frac{\partial \mathcal{L}}{\partial z_q}$$ + +### 损失函数设计 + +总损失由三部分组成: + +$$\mathcal{L}_{VQVAE} = \mathcal{L}_{recon} + \beta \cdot \mathcal{L}_{commit} + \gamma \cdot \mathcal{L}_{uniform}$$ + +| 损失项 | 说明 | +| ----------------------- | ----------------------------------------------------------- | +| $\mathcal{L}_{recon}$ | 重建损失,确保 z 能恢复地图信息(可选,视是否加解码器而定) | +| $\mathcal{L}_{commit}$ | 承诺损失 $\|z_e - \text{sg}(z_q)\|^2$,拉近编码向量与码字 | +| $\mathcal{L}_{uniform}$ | **均匀分布正则化**,鼓励各码字被均等使用(见下节) | + +#### 均匀分布正则化 + +目标:让所有 `K` 个码字在训练中被均等使用,避免 codebook collapse。 + +方案一:**熵最大化** + +$$\mathcal{L}_{uniform} = -H(p) = \sum_{k=1}^{K} p_k \log p_k$$ + +其中 $p_k$ 为码字 $k$ 在当前 batch 中被选中的频率(使用 EMA 估计)。 + +方案二:**EMA 更新 + 重置** + +- 使用指数移动平均(EMA)更新 codebook,不通过梯度更新; +- 定期检测使用率低于阈值的码字,将其重置为当前 batch 中随机样本点。 + +方案三:**codebook 使用 KL 散度** + +$$\mathcal{L}_{uniform} = \text{KL}(p \| \mathcal{U}_{K})$$ + +其中 $\mathcal{U}_K$ 为均匀分布,即期望每个码字被选中概率为 $1/K$。 + +**初始方案**:采用**熵最大化**(方案一),通过梯度直接优化,实现简单;若后续出现 codebook collapse 问题,再引入 EMA 更新 + 重置机制。 + +### 是否需要解码器 + +| | 有解码器 | 无解码器 | +| ---------- | -------------------------- | ----------------------------------------- | +| 重建损失 | 有,z 需要还原地图 | 无,z 的质量由 MaskGIT 的生成质量间接约束 | +| 训练复杂度 | 较高 | 较低 | +| z 语义性 | 强,z 明确包含地图结构信息 | 弱,z 更像风格/多样性控制 | +| 推荐场景 | 需要 z 携带结构语义 | 仅用于多样性控制 | + +**已确认:不加解码器**,直接与 MaskGIT 联合训练。z 的语义由 MaskGIT 端的生成损失反向传播决定。z 定位为风格与多样性控制信号,而非结构重建指导——若 z 能完整重建地图,模型会忽略其他条件直接依赖 z,反而降低可控性与泛化能力。 + +相应地,损失函数简化为: + +$$\mathcal{L} = \mathcal{L}_{CE}(\text{MaskGIT}) + \beta \cdot \mathcal{L}_{commit} + \gamma \cdot \mathcal{L}_{uniform}$$ + +--- + +## 子模型二:MaskGIT 改造 + +### 现有输入 + +```python +def forward(self, map: torch.Tensor, heatmap: torch.Tensor): + # map: [B, H * W] + # heatmap: [B, C, H, W] +``` + +### 改造后输入 + +```python +def forward(self, map: torch.Tensor, z: torch.Tensor): + # map: [B, H * W] — 掩码后的地图 token 序列 + # z: [B, L, d_z] — 离散隐变量(方案B)或 [B, d_z](方案A) +``` + +移除原有的 `GinkaMaskGITCond`(热力图 CNN 编码器)和 `gate_encoder`,用 z 替代其功能。 + +### z 注入 MaskGIT 的方式(待选) + +#### 方式一:Cross-Attention(推荐) + +Transformer 解码器天然支持 cross-attention: + +``` +map token sequence ──► self-attention ──► cross-attention ◄── z + │ + ▼ + 预测 logits +``` + +z 作为 memory 输入到 `TransformerDecoder`,map token 序列作为 query。 + +优点:z 可以是任意长度的序列(方案B友好),表达力强。**已选定此方案。** + +#### 方式二:z 拼接到序列头部(Prefix Token) + +将 z 的每个码字嵌入作为前缀 token,拼接到 map token 序列前: + +``` +[z_0, z_1, ..., z_L, tile_0, tile_1, ..., tile_{H*W}] +``` + +输入到 TransformerEncoder 做自注意力,位置编码需相应扩展。 + +优点:实现简单,无需 encoder-decoder 分离。 + +#### 方式三:Add/FiLM 调制 + +将 z 投影后,用 FiLM(Feature-wise Linear Modulation)调制每一层的特征: + +$$y = \gamma(z) \odot x + \beta(z)$$ + +优点:参数高效;缺点:z 必须先聚合为单向量(适合方案A)。 + +--- + +## 用户使用场景 + +调研结果显示,用户主要有以下两种使用意图: + +| 场景 | 描述 | 对应训练子集 | +| ---------------- | ------------------------------------ | ------------------------- | +| **完全随机生成** | 用户不提供任何条件,直接生成完整地图 | 子集 C(随机墙壁 + Mask) | +| **墙壁辅助生成** | 用户手绘墙壁结构,模型填充非墙内容 | 子集 B / D | + +训练策略需针对这两种场景显式建模,通过多子集混合训练使单一模型同时支持两种工作模式。 + +--- + +## 训练策略 + +### 多子集混合训练 + +每个 batch 从四个子集中按比例采样,各子集对模型能力的针对性不同: + +#### 子集 A:标准 MaskGIT 训练(比例 0.4 ~ 0.6) + +保留原有 MaskGIT 掩码范式,随机遮盖部分 tile,训练模型基本的补全能力。 + +``` +原始地图 ──► 随机 Mask 部分 tile ──► MaskGIT 预测被遮盖 tile +``` + +- 掩码策略:沿用现有 `MapMask`(随机掩码 + 分块掩码 + 形态学变换) +- 作用:维持模型的基础生成质量,防止针对性训练导致退化 + +#### 子集 B:墙壁条件生成(比例 0.1 ~ 0.3) + +去除地图中所有非墙 tile(保留墙壁和空地),让模型在给定墙壁结构下生成其余内容。 + +``` +原始地图 ──► 清除所有非墙 tile(保留 空地/墙壁)──► MaskGIT 生成非墙内容 +``` + +- 输入:仅含墙壁(tile=1)和空地(tile=0)的地图 +- 目标:模型输出完整的地图(含怪物、道具、门、钥匙等) +- 对应场景:**用户手绘墙壁 → 模型填充** + +#### 子集 C:随机条件生成(比例 0.1 ~ 0.3) + +在子集 B 的基础上,进一步对墙壁进行随机 Mask,仅保留部分墙壁作为初始条件,引入随机性以支持完全随机生成场景。 + +``` +原始地图 ──► 清除非墙 tile ──► 随机 Mask 部分墙壁 tile ──► MaskGIT 生成 +``` + +- 输入:随机保留的稀疏墙壁片段 +- 目标:生成完整地图(含墙壁补全 + 非墙内容填充) +- 对应场景:**完全随机生成**(墙壁本身也需要生成) +- 注:z 的随机采样在此子集中发挥关键作用,控制生成风格多样性 + +#### 子集 D:入口条件生成(比例 0.1 ~ 0.2) + +在子集 B / C 的基础上额外保留入口 tile(tile=10),训练模型在给定入口位置时调整生成结果。 + +``` +原始地图 ──► 清除非墙非入口 tile(保留 空地/墙壁/入口)──► MaskGIT 生成 +``` + +- 输入:墙壁结构 + 入口位置 +- 目标:生成与入口位置语义一致的完整地图 +- 对应场景:**用户指定入口 → 模型生成与之匹配的布局** +- 注:单独设立此子集的原因是入口对地图整体拓扑结构有较强约束,需要专项强化 + +### 各子集比例汇总 + +| 子集 | 训练任务 | 建议比例 | +| ---- | -------------------------------- | --------- | +| A | 标准 MaskGIT(随机掩码补全) | 0.4 ~ 0.6 | +| B | 墙壁条件生成(保留全部墙壁) | 0.1 ~ 0.3 | +| C | 随机条件生成(随机保留部分墙壁) | 0.1 ~ 0.3 | +| D | 入口条件生成(保留墙壁 + 入口) | 0.1 ~ 0.2 | + +各子集比例之和为 1,可作为 `dataset.py` 中采样权重使用。 + +### 联合训练流程 + +将 VQ-VAE 编码器与改造后的 MaskGIT 一起训练: + +``` +真实地图 ──► VQ-VAE 编码器 ──► z(离散) + │ +按子集构造的掩码地图 + z ──► MaskGIT ──► 预测 logits ──► 交叉熵损失 + │ + ▼ + VQ 损失(commit + uniform) +``` + +总损失: + +$$\mathcal{L} = \mathcal{L}_{CE}(\text{MaskGIT}) + \beta \cdot \mathcal{L}_{commit} + \gamma \cdot \mathcal{L}_{uniform}$$ + +### Dropout z(提升鲁棒性) + +训练时以一定概率(如 10-20%)将 z 替换为随机采样的码字,模拟推理时随机采样 z 的情况,避免模型过度依赖 z 的精确语义。在子集 C(完全随机生成)中此机制尤为重要。 + +--- + +## 推理流程 + +### 场景一:完全随机生成 + +``` +1. 随机从 [0, K) 均匀采样 L 个 index +2. 从 codebook 查表,得到 z [L, d_z] +3. 初始化地图序列:以少量随机墙壁(比例 0.02 ~ 0.1)作为初始条件,其余位置填充 MASK token +4. 迭代 MaskGIT 解码(cosine schedule): + a. 输入掩码地图 + z → 预测所有 MASK 位置的 logits + b. 取置信度最高的 N 个 token 进行 unmask + c. 重复直到无 MASK token +5. 输出最终地图 +``` + +**初始墙壁比例说明**:全 MASK 初始化会导致生成多样性不佳(即使使用 top-k 采样也难以改善)。引入少量随机墙壁(2%~10%)作为种子,可有效提升多样性;但比例不宜过高,否则随机放置的墙壁可能产生拓扑结构冲突,影响生成质量。 + +### 场景二:墙壁辅助生成(用户手绘墙壁) + +``` +1. 用户提供墙壁布局(tile=1 的位置),其余位置填充 MASK token +2. 随机从 [0, K) 均匀采样 L 个 index,得到 z +3. 迭代 MaskGIT 解码: + a. 输入带已知墙壁的掩码地图 + z → 预测 MASK 位置的 logits + b. 已知墙壁位置不参与 unmask,保持不变 + c. 取置信度最高的 N 个 MASK token 进行 unmask + d. 重复直到无 MASK token +4. 输出最终地图 +``` + +### 场景三:入口指定生成 + +``` +与场景二类似,但初始条件同时固定墙壁位置和入口位置(tile=10) +其余 MASK 位置由 MaskGIT 迭代解码填充 +``` + +--- + +## 超参数(待确定) + +| 参数 | 说明 | 建议初始值 | +| -------------- | --------------------- | ---------- | +| `K` | codebook 大小 | 8 ~ 32 | +| `L` | 码字序列长度(方案B) | 1 ~ 4 | +| `d_z` | 码字嵌入维度 | 64 ~ 128 | +| `β` | commit loss 权重 | 0.25 | +| `γ` | uniform loss 权重 | 0.1 | +| z dropout 概率 | 训练时随机 z 的比例 | 0.1 ~ 0.2 | + +地图固定为 13×13,z 定位为风格与多样性控制信号,K 无需过大(8~32 足够),L 也应保持较小以避免过度约束生成。 + +--- + +## 已决定事项 + +| 决策点 | 结论 | +| -------------- | -------------------------------------------------------------------------- | +| VQ-VAE 解码器 | **不加**,直接与 MaskGIT 联合训练,z 语义由生成损失反向传播决定 | +| 编码器架构方案 | **方案 B**(序列码字),全 Transformer 结构;若训练不稳定可退回方案 A | +| z 注入方式 | **Cross-Attention**(方式一) | +| 均匀正则化策略 | **熵最大化**;若出现 collapse 再引入 EMA | +| 分层 VQ | **单层**,13×13 地图无需多级量化 | +| 标量 cond 保留 | **暂不使用**;训练集标量分布严格,推理阶段难以生成合理值,后续可再考虑加入 | + +## 待探索事项 + +- 合适的 K、L 取值(建议从 K=16, L=2 开始实验) +- z dropout 的最优概率 +- 若后续 codebook collapse:引入 EMA 更新 + 重置机制 +- 若后续需要更细粒度控制:加入标量 cond(需对推理侧标量做随机扰动处理) + +--- + +## 下一步行动 + +- [x] 确定编码方案(方案 B,序列码字) +- [x] 确定 z 注入方式(Cross-Attention) +- [x] 确定是否需要 VQ-VAE 解码器(不加) +- [x] 确定是否保留标量 cond(暂不使用) +- [x] 确定训练子集划分与各场景比例 +- [x] 实现 VQ-VAE 编码器模块(Transformer + VQ) +- [ ] 改造 `GinkaMaskGIT.forward()` 接受 z,替换热力图分支 +- [ ] 实现四种子集采样逻辑(`dataset.py` 新增多子集 Dataset 或采样权重) +- [ ] 实现子集 B/C/D 的输入构造函数(按规则清除 tile、保留墙壁/入口) +- [ ] 编写联合训练脚本,整合 VQ 损失与 MaskGIT 交叉熵损失 diff --git a/ginka/dataset.py b/ginka/dataset.py index 804ead1..beb4b6a 100644 --- a/ginka/dataset.py +++ b/ginka/dataset.py @@ -229,4 +229,204 @@ class GinkaJointDataset(Dataset): "target_map": target_map, "target_heatmap": target_heatmap, "cond_heatmap": cond_heatmap - } \ No newline at end of file + } + + +class GinkaVQDataset(Dataset): + """ + 用于 VQ-VAE + MaskGIT 联合训练的多子集数据集。 + + 每次 __getitem__ 按权重随机选取以下四种子集之一: + A (standard): 标准 MaskGIT 随机掩码,随机遮盖部分 tile + B (wall-only): 仅保留 wall(1) + floor(0),其余全部替换为 MASK(15) + C (wall-random): 在 B 基础上,再随机 mask 部分 wall tile + D (wall+entry): 仅保留 wall(1) + floor(0) + entrance(10),其余全部替换为 MASK(15) + + 返回 dict: + raw_map: LongTensor [H*W] 完整原始地图(供 VQ-VAE 编码) + masked_map: LongTensor [H*W] MaskGIT 输入(被 mask 的位置 = 15) + target_map: LongTensor [H*W] CE loss ground truth(等同 raw_map) + subset: str 子集标识,供调试/统计用 + """ + + FLOOR = 0 + WALL = 1 + ENTRANCE = 10 + MASK_ID = 15 + + def __init__( + self, + data_path: str, + subset_weights: tuple = (0.5, 0.2, 0.2, 0.1), + wall_mask_ratio: float = 0.3, + ): + """ + Args: + data_path: JSON 数据文件路径 + subset_weights: 子集 (A, B, C, D) 的采样权重,自动归一化 + wall_mask_ratio: Subset C 中额外随机 mask 的 wall tile 比例上限 + (每次从 [0, wall_mask_ratio] 均匀采样实际比例) + """ + self.data = load_data(data_path) + self.wall_mask_ratio = wall_mask_ratio + + # 累积权重,用于快速随机子集选择 + total_w = sum(subset_weights) + normalized = [x / total_w for x in subset_weights] + self.subset_cumw = [sum(normalized[:i + 1]) for i in range(len(normalized))] + + def __len__(self): + return len(self.data) + + # ------------------------------------------------------------------ + # 内联随机掩码生成(避免 scipy 的 NumPy 版本兼容问题) + # ------------------------------------------------------------------ + @staticmethod + def _sample_mask_ratio(min_r=0.05, max_r=1.0) -> float: + """用 Beta(2,2) 分布采样掩码比例,集中在中间值。""" + r = np.random.beta(2, 2) + return min_r + (max_r - min_r) * r + + @staticmethod + def _random_mask(h: int, w: int) -> np.ndarray: + """纯随机掩码,返回 [H*W] bool。""" + ratio = GinkaVQDataset._sample_mask_ratio() + total = h * w + idx = np.random.choice(total, int(total * ratio), replace=False) + mask = np.zeros(total, dtype=bool) + mask[idx] = True + return mask + + @staticmethod + def _block_mask(h: int, w: int) -> np.ndarray: + """矩形分块随机掩码,返回 [H*W] bool。""" + ratio = GinkaVQDataset._sample_mask_ratio() + max_block = max(2, min(h, w) // 2) + target = int(h * w * ratio) + mask = np.zeros((h, w), dtype=bool) + while mask.sum() < target: + bh = np.random.randint(2, max_block + 1) + bw = np.random.randint(2, max_block + 1) + x = np.random.randint(0, max(1, h - bh + 1)) + y = np.random.randint(0, max(1, w - bw + 1)) + mask[x:x + bh, y:y + bw] = True + return mask.reshape(-1) + + def _std_mask(self, h: int, w: int) -> np.ndarray: + """标准 MaskGIT 掩码:随机选择纯随机或分块策略。""" + if random.random() < 0.5: + return self._random_mask(h, w) + else: + return self._block_mask(h, w) + + # ------------------------------------------------------------------ + + def _augment(self, arr: np.ndarray) -> np.ndarray: + """随机旋转 / 翻转数据增强,返回新 array。""" + if np.random.rand() > 0.5: + k = np.random.randint(1, 4) + arr = np.rot90(arr, k).copy() + if np.random.rand() > 0.5: + arr = np.fliplr(arr).copy() + if np.random.rand() > 0.5: + arr = np.flipud(arr).copy() + return arr + + def _choose_subset(self) -> str: + r = random.random() + if r < self.subset_cumw[0]: + return 'A' + elif r < self.subset_cumw[1]: + return 'B' + elif r < self.subset_cumw[2]: + return 'C' + else: + return 'D' + + def _apply_subset(self, raw: np.ndarray, subset: str) -> np.ndarray: + """ + 根据子集类型生成 masked_map。 + + Args: + raw: [H, W] int64 完整原始地图 + subset: 'A' | 'B' | 'C' | 'D' + + Returns: + [H*W] int64,被遮盖位置值为 MASK_ID(15) + """ + H, W = raw.shape + + if subset == 'A': + # 标准随机 mask:纯随机或分块策略 + mask = self._std_mask(H, W) # [H*W] bool + flat = raw.reshape(-1).copy() + flat[mask] = self.MASK_ID + return flat + + elif subset == 'B': + # 仅保留 floor(0) 和 wall(1) + flat = raw.reshape(-1).copy() + keep = (flat == self.FLOOR) | (flat == self.WALL) + flat[~keep] = self.MASK_ID + return flat + + elif subset == 'C': + # Subset B + 随机 mask 部分 wall + flat = raw.reshape(-1).copy() + keep = (flat == self.FLOOR) | (flat == self.WALL) + flat[~keep] = self.MASK_ID + + wall_idx = np.where(flat == self.WALL)[0] + if len(wall_idx) > 0: + ratio = random.random() * self.wall_mask_ratio + n = max(1, int(len(wall_idx) * ratio)) + chosen = np.random.choice(wall_idx, n, replace=False) + flat[chosen] = self.MASK_ID + return flat + + else: # D + # 仅保留 floor(0)、wall(1) 和 entrance(10) + flat = raw.reshape(-1).copy() + keep = ( + (flat == self.FLOOR) + | (flat == self.WALL) + | (flat == self.ENTRANCE) + ) + flat[~keep] = self.MASK_ID + return flat + + def __getitem__(self, idx): + item = self.data[idx] + + raw_np = self._augment(np.array(item['map'], dtype=np.int64)) # [H, W] + subset = self._choose_subset() + masked_np = self._apply_subset(raw_np, subset) # [H*W] + raw_flat = raw_np.reshape(-1) # [H*W] + + return { + "raw_map": torch.LongTensor(raw_flat), # VQ-VAE 编码器输入 + "masked_map": torch.LongTensor(masked_np), # MaskGIT 输入 + "target_map": torch.LongTensor(raw_flat.copy()), # CE loss ground truth + "subset": subset, # 调试/统计用 + } + + +if __name__ == "__main__": + import os + data_path = os.path.join(os.path.dirname(__file__), '..', 'data', 'ginka-dataset.json') + ds = GinkaVQDataset(data_path) + print(f"数据集大小: {len(ds)}") + + subset_count = {'A': 0, 'B': 0, 'C': 0, 'D': 0} + for i in range(200): + sample = ds[i % len(ds)] + subset_count[sample['subset']] += 1 + + raw = sample['raw_map'] + masked = sample['masked_map'] + target = sample['target_map'] + print(f"raw_map shape={raw.shape}, dtype={raw.dtype}") + print(f"masked_map shape={masked.shape}, dtype={masked.dtype}") + print(f"target_map shape={target.shape}, dtype={target.dtype}") + print(f"被 mask 的位置数: {(masked == 15).sum().item()} / {masked.numel()}") + print(f"\n200 次采样子集分布: {subset_count}") \ No newline at end of file diff --git a/ginka/maskGIT/maskGIT.py b/ginka/maskGIT/maskGIT.py index aeace76..d41d4fe 100644 --- a/ginka/maskGIT/maskGIT.py +++ b/ginka/maskGIT/maskGIT.py @@ -15,9 +15,15 @@ class Transformer(nn.Module): num_layers=num_layers ) - def forward(self, x): - # x: [B, L, d_model] - m = self.encoder(x) - out = self.decoder(x, m) + def forward(self, x, memory=None): + # x: [B, S, d_model] 地图 token 序列 + # memory: [B, L, d_model] 可选的 z 投影,用于 cross-attention + # 若 memory 为 None,则退化为原始自编解码行为(向后兼容) + enc_out = self.encoder(x) + if memory is not None: + # encoder 输出作为 query,z 作为 key/value + out = self.decoder(enc_out, memory) + else: + out = self.decoder(x, enc_out) return out \ No newline at end of file diff --git a/ginka/maskGIT/model.py b/ginka/maskGIT/model.py index 57f60a0..67b9364 100644 --- a/ginka/maskGIT/model.py +++ b/ginka/maskGIT/model.py @@ -1,88 +1,117 @@ import time import torch import torch.nn as nn -import torch.nn.functional as F from ..utils import print_memory -from .cond import GinkaMaskGITCond from .maskGIT import Transformer + class GinkaMaskGIT(nn.Module): + """ + 改造后的 MaskGIT 地图生成模型。 + + 以掩码地图序列和 VQ-VAE 输出的离散隐变量 z 为输入, + 通过 Transformer encoder-decoder 结构预测被遮盖位置的 tile 类别。 + + z 通过 cross-attention 注入到 Transformer decoder, + 作为风格/多样性控制信号,而非结构重建指导。 + """ + def __init__( - self, num_classes=16, heatmap_channel=4, d_model=256, - dim_ff=512, nhead=8, num_layers=4, map_size=13*13 + self, + num_classes: int = 16, + d_model: int = 192, + d_z: int = 64, + dim_ff: int = 512, + nhead: int = 8, + num_layers: int = 4, + map_size: int = 13 * 13, + z_dropout: float = 0.1, ): + """ + Args: + num_classes: tile 类别数(含 MASK token=15) + d_model: Transformer 内部维度 + d_z: VQ-VAE 码字嵌入维度,需与 GinkaVQVAE.d_z 一致 + dim_ff: 前馈网络隐层维度 + nhead: 注意力头数 + num_layers: Transformer 层数 + map_size: 地图 token 总数(H * W) + z_dropout: 训练时随机替换 z 为随机码字的概率(提升鲁棒性) + """ super().__init__() - + self.z_dropout = z_dropout + + # Tile 嵌入 + 位置编码 self.tile_embedding = nn.Embedding(num_classes, d_model) - self.pos_embedding = nn.Parameter(torch.randn(1, map_size, d_model)) - - cond_channels = [d_model // 4, d_model // 2, d_model] - self.cond_encoder = GinkaMaskGITCond(input_channel=heatmap_channel, channels=cond_channels) - self.gate_encoder = nn.Sequential( - nn.Conv2d(cond_channels[2], cond_channels[2], 3, padding=1, padding_mode="replicate"), - nn.BatchNorm2d(cond_channels[2]), - nn.GELU() + self.pos_embedding = nn.Parameter(torch.randn(1, map_size, d_model) * 0.02) + + # z 投影:将 VQ 码字从 d_z 维映射到 d_model 维,供 cross-attention 使用 + self.z_proj = nn.Sequential( + nn.Linear(d_z, d_model), + nn.LayerNorm(d_model), ) - self.cond_gate = nn.Sequential( - nn.Linear(cond_channels[2] * 2, cond_channels[2]), - nn.LayerNorm(cond_channels[2]), - nn.Dropout(0.3), - nn.GELU(), - - nn.Linear(cond_channels[2], cond_channels[2]) + + # Transformer:encoder 做 map token 自注意力,decoder 做与 z 的 cross-attention + self.transformer = Transformer( + d_model=d_model, dim_ff=dim_ff, nhead=nhead, num_layers=num_layers ) - - self.transformer = Transformer(d_model=d_model, dim_ff=dim_ff, nhead=nhead, num_layers=num_layers) - - self.output_fc = nn.Sequential( - nn.Linear(d_model, num_classes) - ) - - def forward(self, map: torch.Tensor, heatmap: torch.Tensor): - # map: [B, H * W] - # heatmap: [B, C, H, W] - # output: [B, H * W, num_classes] - heatmap = self.cond_encoder(heatmap) # [B, d_model, H, W] - B, C, H, W = heatmap.shape - heatmap_gate = self.gate_encoder(heatmap) - heatmap_mean = F.avg_pool2d(heatmap_gate, (H, W)) # [B, d_model, 1, 1] - heatmap_max = F.max_pool2d(heatmap_gate, (H, W)) # [B, d_model, 1, 1] - gate_input = torch.cat([heatmap_mean, heatmap_max], dim=1).squeeze(2).squeeze(2) - gate = self.cond_gate(gate_input) # [B, d_model] - - heatmap = heatmap * torch.sigmoid(gate).unsqueeze(2).unsqueeze(2) - heatmap = heatmap.view(B, C, H * W).permute(0, 2, 1) - x = self.tile_embedding(map) + heatmap - x = x + self.pos_embedding - x = self.transformer(x) - - logits = self.output_fc(x) - + + self.output_fc = nn.Linear(d_model, num_classes) + + def forward(self, map: torch.Tensor, z: torch.Tensor) -> torch.Tensor: + """ + Args: + map: [B, H*W] 掩码后的地图 token 序列(MASK token = 15) + z: [B, L, d_z] VQ-VAE 量化后的离散隐变量 + + Returns: + logits: [B, H*W, num_classes] + """ + # z dropout:训练时以一定概率将 z 替换为随机均匀噪声, + # 模拟推理时随机采样 z 的分布,避免模型过拟合于精确的 z 语义 + if self.training and self.z_dropout > 0: + mask = torch.rand(z.shape[0], 1, 1, device=z.device) < self.z_dropout + rand_z = torch.randn_like(z) + z = torch.where(mask, rand_z, z) + + # 投影 z 到 d_model 维度 + z_mem = self.z_proj(z) # [B, L, d_model] + + # tile embedding + 位置编码 + x = self.tile_embedding(map) # [B, H*W, d_model] + x = x + self.pos_embedding # [B, H*W, d_model] + + # Transformer:encoder 做 map 自注意力,decoder cross-attend z + x = self.transformer(x, memory=z_mem) # [B, H*W, d_model] + + logits = self.output_fc(x) # [B, H*W, num_classes] return logits - + + if __name__ == "__main__": device = torch.device("cpu") - - map = torch.randint(0, 16, [1, 169]).to(device) - heatmap = torch.rand(1, 4, 13, 13).to(device) - - # 初始化模型 - model = GinkaMaskGIT().to(device) - - print_memory("初始化后") - - # 前向传播 - start = time.perf_counter() - output = model(map, heatmap) - end = time.perf_counter() - - print_memory("前向传播后") - - print(f"推理耗时: {end - start}") - print(f"输出形状: output={output.shape}") - print(f"Tile Embedding parameters: {sum(p.numel() for p in model.tile_embedding.parameters())}") - print(f"Condition Encoder parameters: {sum(p.numel() for p in model.cond_encoder.parameters())}") - print(f"Condition Gate parameters: {sum(p.numel() for p in model.cond_gate.parameters())}") - print(f"MaskGIT parameters: {sum(p.numel() for p in model.transformer.parameters())}") - print(f"Output parameters: {sum(p.numel() for p in model.output_fc.parameters())}") - print(f"Total parameters: {sum(p.numel() for p in model.parameters())}") + + model = GinkaMaskGIT( + num_classes=16, + d_model=192, + d_z=64, + dim_ff=512, + nhead=8, + num_layers=4, + map_size=13 * 13, + ).to(device) + + total_params = sum(p.numel() for p in model.parameters()) + print(f"总参数量: {total_params:,} ({total_params / 1e6:.3f}M)") + for name, module in model.named_children(): + n = sum(p.numel() for p in module.parameters()) + print(f" {name}: {n:,}") + + map_input = torch.randint(0, 16, (4, 13 * 13)).to(device) # [B=4, 169] + z_input = torch.randn(4, 2, 64).to(device) # [B=4, L=2, d_z=64] + + model.train() + logits = model(map_input, z_input) + print(f"\nlogits shape: {logits.shape}") # [4, 169, 16] + + print_memory(device, "前向传播后") diff --git a/ginka/train_vq.py b/ginka/train_vq.py new file mode 100644 index 0000000..fe0bf0b --- /dev/null +++ b/ginka/train_vq.py @@ -0,0 +1,520 @@ +""" +联合训练脚本:VQ-VAE + MaskGIT + +总损失 = L_CE(MaskGIT 重建损失)+ beta * L_commit + gamma * L_entropy + +验证阶段对四种子集(A/B/C/D)分别输出图片, +每条样本额外采样 N_Z_SAMPLES 个随机 z, +便于直观对比同条件不同 z 下的生成差异。 + +用法示例: + python -m ginka.train_vq + python -m ginka.train_vq --resume True --state result/joint/joint-10.pth +""" + +import argparse +import math +import os +import sys +import random +from datetime import datetime + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +import torch.optim as optim +from tqdm import tqdm +from torch.utils.data import DataLoader + +from .vqvae.model import GinkaVQVAE +from .maskGIT.model import GinkaMaskGIT +from .dataset import GinkaVQDataset +from shared.image import matrix_to_image_cv + +# --------------------------------------------------------------------------- +# 超参数 +# --------------------------------------------------------------------------- +BATCH_SIZE = 64 +NUM_CLASSES = 16 +MASK_TOKEN = 15 +GENERATE_STEP = 12 # 推理时 MaskGIT 迭代步数 +MAP_SIZE = 13 * 13 +MAP_H = MAP_W = 13 +LABEL_SMOOTHING = 0.0 + +# VQ-VAE 超参 +VQ_L = 2 # summary token 数量(即 z 的序列长度) +VQ_K = 16 # codebook 大小 +VQ_D_Z = 64 # codebook 嵌入维度 +VQ_D_MODEL= 128 +VQ_NHEAD = 4 +VQ_LAYERS = 2 +VQ_DIM_FF = 256 +VQ_BETA = 0.25 # commit loss 权重 +VQ_GAMMA = 0.1 # entropy loss 权重 + +# MaskGIT 超参 +MG_D_MODEL = 192 +MG_NHEAD = 8 +MG_LAYERS = 4 +MG_DIM_FF = 512 +MG_Z_DROPOUT= 0.15 # 训练时以此概率把 z 替换为随机噪声 + +# 验证时对每条样本额外采样的 z 数量(0 = 只用真实 z) +N_Z_SAMPLES = 3 + +# 四个子集 A/B/C/D 的采样权重(训练集与验证集共用) +SUBSET_WEIGHTS = (0.5, 0.2, 0.2, 0.1) + +# --------------------------------------------------------------------------- +# 设备 +# --------------------------------------------------------------------------- +device = torch.device( + "cuda:1" if torch.cuda.is_available() + else "mps" if torch.backends.mps.is_available() + else "cpu" +) + +os.makedirs("result/joint", exist_ok=True) +os.makedirs("result/joint_img", exist_ok=True) + +disable_tqdm = not sys.stdout.isatty() + +# --------------------------------------------------------------------------- +# 参数解析 +# --------------------------------------------------------------------------- +def parse_arguments(): + parser = argparse.ArgumentParser(description="VQ-VAE + MaskGIT 联合训练") + parser.add_argument("--resume", type=bool, default=False) + parser.add_argument("--state", type=str, default="result/joint/joint-10.pth", + help="续训时加载的检查点路径") + parser.add_argument("--train", type=str, default="data/ginka-dataset.json") + parser.add_argument("--validate", type=str, default="data/ginka-eval.json") + parser.add_argument("--epochs", type=int, default=100) + parser.add_argument("--checkpoint", type=int, default=5, + help="每隔多少 epoch 保存检查点并验证") + parser.add_argument("--load_optim", type=bool, default=True) + return parser.parse_args() + +# --------------------------------------------------------------------------- +# MaskGIT 推理(cosine schedule 迭代解码) +# --------------------------------------------------------------------------- +@torch.no_grad() +def maskgit_generate( + model_mg: GinkaMaskGIT, + z: torch.Tensor, + steps: int = GENERATE_STEP, + init_map: torch.Tensor = None, +) -> torch.Tensor: + """ + 迭代生成地图(cosine schedule unmasking)。 + + Args: + model_mg: GinkaMaskGIT(eval 模式) + z: [B, L, d_z] 条件 z + steps: 解码步数 + init_map: [B, MAP_SIZE] 可选初始地图;非 MASK 位置在生成过程中保持固定。 + 为 None 时从全 MASK 开始(自由生成)。 + + Returns: + map_out: [B, MAP_SIZE] + """ + B = z.shape[0] + if init_map is None: + map_seq = torch.full((B, MAP_SIZE), MASK_TOKEN, device=device) + else: + map_seq = init_map.clone().to(device) + + # 记录初始 MASK 位置,这些位置才需要生成 + generatable = (map_seq == MASK_TOKEN) # [B, S] bool + + for step in range(steps): + if not generatable.any(): + break + + logits = model_mg(map_seq, z) # [B, S, C] + probs = F.softmax(logits, dim=-1) + dist = torch.distributions.Categorical(probs) + sampled = dist.sample() # [B, S] + + # 计算置信度;固定位置设为 +inf(确保不会被选为“继续保持 MASK”) + confidences = torch.gather(probs, -1, sampled.unsqueeze(-1)).squeeze(-1) + confidences = confidences.masked_fill(~generatable, float('inf')) + + ratio = math.cos(((step + 1) / steps) * math.pi / 2) + new_map = map_seq.clone() + + for b in range(B): + n_gen = int(generatable[b].sum().item()) + n_keep = int(ratio * n_gen) # 本步仍保持 MASK 的位置数 + + if n_keep > 0: + _, keep_idx = torch.topk(confidences[b], k=n_keep, largest=False) + pred_b = sampled[b].clone() + pred_b[keep_idx] = MASK_TOKEN + new_map[b] = torch.where(generatable[b], pred_b, map_seq[b]) + else: + new_map[b] = torch.where(generatable[b], sampled[b], map_seq[b]) + + map_seq = new_map + + return map_seq + +# --------------------------------------------------------------------------- +# 可视化工具 +# --------------------------------------------------------------------------- +def make_map_image(map_flat: torch.Tensor, tile_dict: dict) -> np.ndarray: + """将 [MAP_SIZE] 的 tensor 转成 RGB 图片(numpy)。""" + arr = map_flat.cpu().numpy().reshape(MAP_H, MAP_W) + return matrix_to_image_cv(arr, tile_dict) + + +def hstack_images(imgs: list, gap: int = 4, color=(255, 255, 255)) -> np.ndarray: + """将多张等高图片横向拼接,之间插入白色竖线。""" + H = imgs[0].shape[0] + vline = np.full((H, gap, 3), color, dtype=np.uint8) + result = imgs[0] + for img in imgs[1:]: + result = np.concatenate([result, vline, img], axis=1) + return result + + +def label_image(img: np.ndarray, text: str, font_scale: float = 0.45) -> np.ndarray: + """在图片顶部加一行文字标签(就地修改并返回)。""" + bar_h = 16 + bar = np.full((bar_h, img.shape[1], 3), (40, 40, 40), dtype=np.uint8) + cv2.putText( + bar, text, (2, bar_h - 3), + cv2.FONT_HERSHEY_SIMPLEX, font_scale, + (200, 200, 200), 1, cv2.LINE_AA + ) + return np.concatenate([bar, img], axis=0) + + +def make_random_wall_seed(ratio_min: float = 0.02, ratio_max: float = 0.08) -> torch.Tensor: + """ + 在全 MASK 地图上随机放置少量墙壁作为推理种子,用于完全随机生成场景。 + + Returns: + [1, MAP_SIZE] MASK=15 背景 + 随机置少量墙壁(tile=1) + """ + ratio = random.uniform(ratio_min, ratio_max) + n_wall = max(2, int(MAP_SIZE * ratio)) + seed = torch.full((1, MAP_SIZE), MASK_TOKEN, dtype=torch.long, device=device) + idx = torch.randperm(MAP_SIZE)[:n_wall] + seed[0, idx] = 1 # wall + return seed + +@torch.no_grad() +def validate( + model_vq: GinkaVQVAE, + model_mg: GinkaMaskGIT, + dataloader_val: DataLoader, + tile_dict: dict, + epoch: int, +): + """ + 验证函数:计算 val loss 并输出 5 类推理场景的对比图。 + + 场景说明(按 epoch 建立子文件夹,避免图片堆积): + 场景1 (scene1_completion) : 子集 A,标准随机掩码补全 + 列: ground truth | masked input | z_real pred | z_real gen | z_rand×N + 场景2 (scene2_wall) : 子集 B,仅墙壁+空地 → 生成完整地图 + 列: ground truth | wall-only input | z_real gen | z_rand×N + 场景3 (scene3_sparse) : 子集 C,稀疏墙壁条件 → 生成完整地图 + 列: ground truth | sparse wall input | z_real gen | z_rand×N + 场景4 (scene4_entrance) : 子集 D,墙壁+入口 → 生成完整地图 + 列: ground truth | wall+entrance input | z_real gen | z_rand×N + 场景5 (scene5_random) : 无数据集参照,随机稀疏墙壁种子 → 完全随机生成 + 列: random seed | z_rand×(N+1) + """ + model_vq.eval() + model_mg.eval() + + # 按 epoch 建立独立子文件夹,保留每次验证结果方便回溯 + epoch_dir = f"result/joint_img/e{epoch:04d}" + os.makedirs(epoch_dir, exist_ok=True) + + val_loss_total = 0.0 + val_steps = 0 + captured = {s: None for s in ('A', 'B', 'C', 'D')} + + # ── 计算 val loss + 捕获各子集样本 ────────────────────────────────────── + for batch in tqdm(dataloader_val, desc="Validating", leave=False, disable=disable_tqdm): + raw_map = batch["raw_map"].to(device) # [B, 169] + masked_map = batch["masked_map"].to(device) # [B, 169] + target_map = batch["target_map"].to(device) # [B, 169] + subsets = batch["subset"] # list of str + B = raw_map.shape[0] + + z_q, _, vq_loss = model_vq(raw_map) + logits = model_mg(masked_map, z_q) + mask = (masked_map == MASK_TOKEN) + + ce_loss = F.cross_entropy( + logits.permute(0, 2, 1), target_map, + reduction='none', label_smoothing=LABEL_SMOOTHING + ) + masked_ce = (ce_loss * mask).sum() / (mask.sum() + 1e-6) + val_loss_total += (masked_ce + vq_loss).item() + val_steps += 1 + + for i in range(B): + s = subsets[i] + if captured[s] is None: + captured[s] = { + "raw": raw_map[i:i+1].clone(), + "masked": masked_map[i:i+1].clone(), + "z_q": z_q[i:i+1].clone(), + } + + if all(v is not None for v in captured.values()): + break + + # ── 公共辅助:对给定条件地图随机采样 n 次 z 并迭代生成 ────────────────── + def _rand_gens(cond_map, n): + imgs = [] + for i in range(n): + z_r = model_vq.sample(1, device) + gen = maskgit_generate(model_mg, z_r, init_map=cond_map) + imgs.append(label_image(make_map_image(gen[0], tile_dict), f"z_rand_{i + 1}")) + return imgs + + # ── 场景1:标准掩码补全(子集 A)───────────────────────────────────────── + if captured['A'] is not None: + cap = captured['A'] + raw, cond, z_q = cap['raw'], cap['masked'], cap['z_q'] + + real_img = label_image(make_map_image(raw[0], tile_dict), "ground truth") + cond_img = label_image(make_map_image(cond[0], tile_dict), "masked input") + + # 单步 argmax 预测(观察模型对掩码位置的瞬时判断) + pred = model_mg(cond, z_q).argmax(dim=-1)[0] + pred_img = label_image(make_map_image(pred, tile_dict), "z_real pred") + + # 迭代生成(从掩码输入出发,真实 z) + gen_real = maskgit_generate(model_mg, z_q, init_map=cond) + gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen") + + row = [real_img, cond_img, pred_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES) + cv2.imwrite(f"{epoch_dir}/scene1_completion.png", hstack_images(row)) + + # ── 场景2:墙壁辅助生成(子集 B)───────────────────────────────────────── + if captured['B'] is not None: + cap = captured['B'] + raw, cond, z_q = cap['raw'], cap['masked'], cap['z_q'] + + real_img = label_image(make_map_image(raw[0], tile_dict), "ground truth") + cond_img = label_image(make_map_image(cond[0], tile_dict), "wall-only input") + gen_real = maskgit_generate(model_mg, z_q, init_map=cond) + gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen") + + row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES) + cv2.imwrite(f"{epoch_dir}/scene2_wall.png", hstack_images(row)) + + # ── 场景3:稀疏墙壁条件生成(子集 C)──────────────────────────────────── + if captured['C'] is not None: + cap = captured['C'] + raw, cond, z_q = cap['raw'], cap['masked'], cap['z_q'] + + real_img = label_image(make_map_image(raw[0], tile_dict), "ground truth") + cond_img = label_image(make_map_image(cond[0], tile_dict), "sparse wall input") + gen_real = maskgit_generate(model_mg, z_q, init_map=cond) + gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen") + + row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES) + cv2.imwrite(f"{epoch_dir}/scene3_sparse.png", hstack_images(row)) + + # ── 场景4:墙壁+入口条件生成(子集 D)─────────────────────────────────── + if captured['D'] is not None: + cap = captured['D'] + raw, cond, z_q = cap['raw'], cap['masked'], cap['z_q'] + + real_img = label_image(make_map_image(raw[0], tile_dict), "ground truth") + cond_img = label_image(make_map_image(cond[0], tile_dict), "wall+entrance input") + gen_real = maskgit_generate(model_mg, z_q, init_map=cond) + gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen") + + row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES) + cv2.imwrite(f"{epoch_dir}/scene4_entrance.png", hstack_images(row)) + + # ── 场景5:完全随机生成(无数据集参照)────────────────────────────────── + # 随机稀疏墙壁种子,对应推理时"不提供任何条件,直接生成"的场景 + rand_seed = make_random_wall_seed() # [1, MAP_SIZE] + seed_img = label_image(make_map_image(rand_seed[0], tile_dict), "random seed") + row = [seed_img] + _rand_gens(rand_seed, N_Z_SAMPLES + 1) # 多采一个 z 展示多样性 + cv2.imwrite(f"{epoch_dir}/scene5_random.png", hstack_images(row)) + + avg_val_loss = val_loss_total / max(val_steps, 1) + return avg_val_loss + +# --------------------------------------------------------------------------- +# 主训练函数 +# --------------------------------------------------------------------------- +def train(): + print(f"Using device: {device}") + args = parse_arguments() + + # ---- 模型 ---- + model_vq = GinkaVQVAE( + num_classes=NUM_CLASSES, + L=VQ_L, K=VQ_K, d_z=VQ_D_Z, + d_model=VQ_D_MODEL, nhead=VQ_NHEAD, + num_layers=VQ_LAYERS, dim_ff=VQ_DIM_FF, + map_size=MAP_SIZE, + beta=VQ_BETA, gamma=VQ_GAMMA, + ).to(device) + + model_mg = GinkaMaskGIT( + num_classes=NUM_CLASSES, + d_model=MG_D_MODEL, d_z=VQ_D_Z, + dim_ff=MG_DIM_FF, nhead=MG_NHEAD, + num_layers=MG_LAYERS, + map_size=MAP_SIZE, + z_dropout=MG_Z_DROPOUT, + ).to(device) + + vq_params = sum(p.numel() for p in model_vq.parameters()) + mg_params = sum(p.numel() for p in model_mg.parameters()) + print(f"VQ-VAE 参数量: {vq_params:,} ({vq_params/1e6:.3f}M)") + print(f"MaskGIT 参数量: {mg_params:,} ({mg_params/1e6:.3f}M)") + print(f"Total 参数量: {vq_params+mg_params:,} ({(vq_params+mg_params)/1e6:.3f}M)") + + # ---- 数据集 ---- + dataset_train = GinkaVQDataset( + args.train, + subset_weights=SUBSET_WEIGHTS, + ) + dataset_val = GinkaVQDataset( + args.validate, + subset_weights=SUBSET_WEIGHTS, + ) + dataloader_train = DataLoader( + dataset_train, batch_size=BATCH_SIZE, shuffle=True, + num_workers=0, pin_memory=(device.type == "cuda"), + ) + dataloader_val = DataLoader( + dataset_val, batch_size=8, shuffle=True, + num_workers=0, + ) + + # ---- 优化器(联合训练,两个模型共用一个 optimizer)---- + all_params = list(model_vq.parameters()) + list(model_mg.parameters()) + optimizer = optim.AdamW(all_params, lr=2e-4, weight_decay=1e-2) + scheduler = optim.lr_scheduler.CosineAnnealingLR( + optimizer, T_max=args.epochs, eta_min=1e-6 + ) + + # ---- 续训 ---- + start_epoch = 0 + if args.resume: + ckpt = torch.load(args.state, map_location=device) + model_vq.load_state_dict(ckpt["vq_state"], strict=False) + model_mg.load_state_dict(ckpt["mg_state"], strict=False) + if args.load_optim and ckpt.get("optim_state") is not None: + optimizer.load_state_dict(ckpt["optim_state"]) + start_epoch = ckpt.get("epoch", 0) + print(f"从 epoch {start_epoch} 接续训练。") + + # ---- tile 贴图(用于验证可视化)---- + tile_dict = {} + for file in os.listdir("tiles"): + name = os.path.splitext(file)[0] + img = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED) + if img is not None: + tile_dict[name] = img + + # ---- 训练循环 ---- + for epoch in tqdm(range(start_epoch, start_epoch + args.epochs), + desc="Joint Training", disable=disable_tqdm): + model_vq.train() + model_mg.train() + + loss_total = 0.0 + ce_total = 0.0 + vq_loss_total = 0.0 + subset_stats = {'A': 0, 'B': 0, 'C': 0, 'D': 0} + + for batch in tqdm(dataloader_train, leave=False, + desc="Epoch Progress", disable=disable_tqdm): + raw_map = batch["raw_map"].to(device) # [B, 169] + masked_map = batch["masked_map"].to(device) # [B, 169] + target_map = batch["target_map"].to(device) # [B, 169] + + for s in batch["subset"]: + subset_stats[s] = subset_stats.get(s, 0) + 1 + + # ---- 前向传播 ---- + # 1. VQ-VAE 编码真实地图 → z_q + z_q, _, vq_loss = model_vq(raw_map) # z_q: [B, L, d_z] + + # 2. MaskGIT 以掩码地图 + z 预测原始 tile + logits = model_mg(masked_map, z_q) # [B, 169, C] + + # 3. 只对被 mask 的位置计算 CE loss + mask = (masked_map == MASK_TOKEN) # [B, 169] bool + ce_loss = F.cross_entropy( + logits.permute(0, 2, 1), target_map, + reduction='none', label_smoothing=LABEL_SMOOTHING + ) + masked_ce = (ce_loss * mask).sum() / (mask.sum() + 1e-6) + + # 4. 联合损失 + loss = masked_ce + vq_loss + + optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(all_params, max_norm=1.0) + optimizer.step() + + loss_total += loss.detach().item() + ce_total += masked_ce.detach().item() + vq_loss_total += vq_loss.detach().item() + + scheduler.step() + + n = len(dataloader_train) + tqdm.write( + f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " + f"Epoch {epoch + 1:4d} | " + f"Loss {loss_total/n:.5f} " + f"CE {ce_total/n:.5f} " + f"VQ {vq_loss_total/n:.5f} | " + f"LR {scheduler.get_last_lr()[0]:.6f} | " + f"Subsets {subset_stats}" + ) + + # ---- 检查点 + 验证 ---- + if (epoch + 1) % args.checkpoint == 0: + ckpt_path = f"result/joint/joint-{epoch + 1}.pth" + torch.save({ + "epoch": epoch + 1, + "vq_state": model_vq.state_dict(), + "mg_state": model_mg.state_dict(), + "optim_state":optimizer.state_dict(), + }, ckpt_path) + tqdm.write(f" 检查点已保存: {ckpt_path}") + + val_loss = validate( + model_vq, model_mg, dataloader_val, tile_dict, epoch + 1 + ) + tqdm.write( + f"[Validate] Epoch {epoch + 1:4d} | Val Loss {val_loss:.5f}" + ) + # 恢复训练模式 + model_vq.train() + model_mg.train() + + print("训练结束。") + torch.save({ + "epoch": start_epoch + args.epochs, + "vq_state": model_vq.state_dict(), + "mg_state": model_mg.state_dict(), + }, "result/joint/joint_final.pth") + + +# --------------------------------------------------------------------------- +if __name__ == "__main__": + torch.set_num_threads(4) + train() diff --git a/ginka/vqvae/__init__.py b/ginka/vqvae/__init__.py new file mode 100644 index 0000000..e01c7b9 --- /dev/null +++ b/ginka/vqvae/__init__.py @@ -0,0 +1 @@ +from .model import GinkaVQVAE diff --git a/ginka/vqvae/model.py b/ginka/vqvae/model.py new file mode 100644 index 0000000..79ee4c6 --- /dev/null +++ b/ginka/vqvae/model.py @@ -0,0 +1,186 @@ +import torch +import torch.nn as nn +from .quantize import VectorQuantizer +from typing import Tuple + + +class GinkaVQVAE(nn.Module): + """ + VQ-VAE 风格地图编码器。 + + 将一张完整的地图([B, H*W] 整数 tile ID 序列)编码为 L 个离散码字, + 输出 z [B, L, d_z] 作为 MaskGIT 模型的生成条件。 + + 架构: + tile embedding + 位置编码 + → L 个可学习 summary token(拼接到序列头部) + → Transformer Encoder(Pre-LN,自注意力) + → 取前 L 个输出 + → 线性投影到 d_z + → VectorQuantizer(直通估计 + 熵最大化正则) + + 设计约束: + - 参数量目标 < 1M + - 不含解码器,z 的语义由 MaskGIT 端的交叉熵损失间接约束 + - z 定位为风格/多样性控制信号,而非结构重建指导 + """ + + def __init__( + self, + num_classes: int = 16, + L: int = 2, + K: int = 16, + d_z: int = 64, + d_model: int = 128, + nhead: int = 4, + num_layers: int = 2, + dim_ff: int = 256, + map_size: int = 13 * 13, + beta: float = 0.25, + gamma: float = 0.1, + vq_temp: float = 1.0, + ): + """ + Args: + num_classes: tile 类别数(含 MASK token) + L: 码字序列长度,即 z 的序列维度 + K: codebook 大小(码字总数) + d_z: 码字嵌入维度 + d_model: Transformer 内部维度 + nhead: 注意力头数 + num_layers: Transformer 层数 + dim_ff: 前馈网络隐层维度 + map_size: 地图 token 总数(H * W) + beta: 承诺损失权重 + gamma: 熵正则损失权重 + vq_temp: VQ 软分配 softmax 温度 + """ + super().__init__() + self.L = L + self.K = K + self.d_z = d_z + self.beta = beta + self.gamma = gamma + + # Tile 嵌入 + self.tile_embedding = nn.Embedding(num_classes, d_model) + + # 地图位置编码(仅覆盖 map_size 个位置,不含 summary token) + self.pos_embedding = nn.Parameter(torch.randn(1, map_size, d_model) * 0.02) + + # L 个可学习 summary token,拼接到序列头部 + self.summary_tokens = nn.Parameter(torch.randn(1, L, d_model) * 0.02) + + # Pre-LN Transformer Encoder(训练更稳定) + encoder_layer = nn.TransformerEncoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_ff, + batch_first=True, + activation='gelu', + norm_first=True, # Pre-LN + dropout=0.0, + ) + self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) + + # 将 Transformer 输出投影到 codebook 维度 d_z + self.proj = nn.Sequential( + nn.Linear(d_model, d_z), + nn.LayerNorm(d_z), + ) + + # 向量量化层 + self.vq = VectorQuantizer(K=K, d_z=d_z, temp=vq_temp) + + def encode(self, map: torch.Tensor) -> torch.Tensor: + """ + 将地图编码为量化前的连续向量序列。 + + Args: + map: [B, H*W] 整数 tile ID + + Returns: + z_e: [B, L, d_z] 量化前的编码向量 + """ + B = map.shape[0] + + x = self.tile_embedding(map) # [B, H*W, d_model] + x = x + self.pos_embedding # [B, H*W, d_model] + + summary = self.summary_tokens.expand(B, -1, -1) # [B, L, d_model] + x = torch.cat([summary, x], dim=1) # [B, L+H*W, d_model] + + x = self.transformer(x) # [B, L+H*W, d_model] + + z_e = self.proj(x[:, :self.L]) # [B, L, d_z] + return z_e + + def forward(self, map: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + 完整前向传播:编码 → 量化 → 计算损失。 + + Args: + map: [B, H*W] 整数 tile ID(训练时传入完整真实地图) + + Returns: + z_q: [B, L, d_z] 量化后的 z(含直通梯度),供 MaskGIT 使用 + indices: [B, L] 每个位置对应的码字索引 + vq_loss: scalar VQ 总损失 = beta * commit_loss + gamma * entropy_loss + """ + z_e = self.encode(map) + z_q, indices, commit_loss, entropy_loss = self.vq(z_e) + + vq_loss = self.beta * commit_loss + self.gamma * entropy_loss + return z_q, indices, vq_loss + + def sample(self, B: int, device: torch.device) -> torch.Tensor: + """ + 推理阶段:从 codebook 中随机均匀采样 L 个码字。 + + Args: + B: batch size + device: 目标设备 + + Returns: + z: [B, L, d_z] + """ + indices = torch.randint(0, self.K, (B, self.L), device=device) + z = self.vq.codebook(indices) # [B, L, d_z] + return z + + +if __name__ == "__main__": + device = torch.device("cpu") + + model = GinkaVQVAE( + num_classes=16, + L=2, + K=16, + d_z=64, + d_model=128, + nhead=4, + num_layers=2, + dim_ff=256, + map_size=13 * 13, + ).to(device) + + total_params = sum(p.numel() for p in model.parameters()) + print(f"总参数量: {total_params:,} ({total_params / 1e6:.3f}M)") + + # 分模块参数统计 + for name, module in model.named_children(): + n = sum(p.numel() for p in module.parameters()) + print(f" {name}: {n:,}") + + # 前向传播测试 + map_input = torch.randint(0, 15, (4, 13 * 13)).to(device) # [B=4, 169] + + z_q, indices, vq_loss = model(map_input) + + print(f"\nz_q shape: {z_q.shape}") # [4, 2, 64] + print(f"indices shape:{indices.shape}") # [4, 2] + print(f"vq_loss: {vq_loss.item():.4f}") + + # 推理采样测试 + z_sample = model.sample(B=4, device=device) + print(f"sample shape: {z_sample.shape}") # [4, 2, 64] diff --git a/ginka/vqvae/quantize.py b/ginka/vqvae/quantize.py new file mode 100644 index 0000000..6f23e11 --- /dev/null +++ b/ginka/vqvae/quantize.py @@ -0,0 +1,81 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Tuple + + +class VectorQuantizer(nn.Module): + """ + 向量量化层(Vector Quantization)。 + + 将连续的编码向量序列映射到离散的 codebook 码字索引, + 并通过直通估计(Straight-Through Estimator)保持梯度流。 + + 均匀分布正则化采用软分配熵最大化方案: + 通过对距离做 softmax 得到软分配概率,计算平均码字使用率的熵, + 最小化负熵以鼓励所有码字被均等使用。 + """ + + def __init__(self, K: int, d_z: int, temp: float = 1.0): + """ + Args: + K: codebook 大小(码字数量) + d_z: 码字嵌入维度 + temp: 软分配 softmax 温度,越小越接近 hard assignment + """ + super().__init__() + self.K = K + self.d_z = d_z + self.temp = temp + + self.codebook = nn.Embedding(K, d_z) + nn.init.uniform_(self.codebook.weight, -1.0 / K, 1.0 / K) + + def forward(self, z_e: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + z_e: [B, L, d_z] 编码器输出的连续向量序列 + + Returns: + z_q_st: [B, L, d_z] 量化后向量(直通梯度) + indices: [B, L] 每个位置对应的码字索引 + commit_loss: scalar 承诺损失 ||z_e - sg(z_q)||^2 + entropy_loss: scalar 负熵损失(最小化 = 最大化码字使用均匀度) + """ + B, L, d_z = z_e.shape + + # 展平到 [B*L, d_z] + z_flat = z_e.reshape(B * L, d_z) + + codebook_w = self.codebook.weight # [K, d_z] + + # 计算 L2 距离:||z_e - e_k||^2 = ||z_e||^2 + ||e_k||^2 - 2 * z_e · e_k + # distances: [B*L, K] + distances = ( + (z_flat ** 2).sum(dim=1, keepdim=True) # [B*L, 1] + + (codebook_w ** 2).sum(dim=1) # [K] + - 2.0 * z_flat @ codebook_w.t() # [B*L, K] + ) + + # Hard assignment:取最近码字索引 + indices = distances.argmin(dim=1) # [B*L] + + # 量化向量 + z_q_flat = self.codebook(indices) # [B*L, d_z] + z_q = z_q_flat.reshape(B, L, d_z) + + # 直通估计:前向传 z_q,反向传 z_e 的梯度 + z_q_st = z_e + (z_q - z_e).detach() + + # 承诺损失:拉近编码向量与其对应的码字(仅更新编码器) + commit_loss = F.mse_loss(z_e, z_q.detach()) + + # 熵最大化正则:通过软分配计算平均码字使用率,最小化负熵 + # soft_assign: [B*L, K],对距离做 softmax(距离越小,概率越大) + soft_assign = F.softmax(-distances / self.temp, dim=1) + avg_assign = soft_assign.mean(dim=0) # [K],平均码字使用率 + # entropy_loss = -H(p) = sum(p * log(p)),最小化即最大化熵 + entropy_loss = (avg_assign * torch.log(avg_assign + 1e-10)).sum() + + indices = indices.reshape(B, L) + return z_q_st, indices, commit_loss, entropy_loss