diff --git a/docs/vqvae-split-channel-design.md b/docs/vqvae-split-channel-design.md new file mode 100644 index 0000000..d3d483d --- /dev/null +++ b/docs/vqvae-split-channel-design.md @@ -0,0 +1,322 @@ +# VQ 编码器三通道分拆预训练设计文档 + +## 背景与问题诊断 + +### 核心问题 + +在当前 VQ-VAE + MaskGIT 联合训练方案(以及方案 A 闭环约束、方案 D 全图预训练)中,一个根本性的类别不均衡问题始终未被解决: + +地图中约 70–85% 的格子为墙壁(tile=1)和空地(tile=0),而怪物(9)、门(2)、入口(10)、宝石(4/5/6)、钥匙(3)、药水(7)、道具(8)等功能性 tile 仅占少数。当编码器以完整地图为输入、重建损失对所有位置平等计算时: + +- 编码器的梯度信号被墙壁/空地主导,特征向量主要编码了空间结构,而非功能性内容; +- 解码(或生成)时,怪物/门/资源类 tile 的召回率远低于墙壁; +- 即使方案 D 进行了全图预训练,由于损失仍是全位置等权 Focal Loss,预训练阶段本身就强化了这种偏向; +- **已将损失由 CE 替换为 Focal Loss,但召回率仍无明显改善**——说明问题的根源在于梯度信号的来源比例,而非损失函数本身的形式; +- 方案 A 的闭环一致性约束在 z 本身质量不足时,仅能强化已有的偏向,难以从根本上改善。 + +### 改进目标 + +通过**按语义层次将地图拆分为三个独立通道**,采用**累积式输入 + 通道专属损失**的设计:每个通道的编码器输入包含当前通道及所有低等级 tile(提供空间上下文),而解码头的损失**仅在本通道专属 tile 的位置计算,低等级 tile 的损失权重为 0**。这样既保证了编码器能在有意义的空间结构中学习(避免功能性 tile 沦为孤立散点),又迫使解码器必须正确预测本通道 tile 才能降低损失——无法靠拟合高频的墙壁/空地来刷低损失值。 + +--- + +## Tile 类型划分 + +根据 `data/src/shared.ts` 中的定义,将 15 个有效 tile 类型(不含 MASK=15)按游戏语义分为三个通道: + +| 通道 | 编码器输入(切片内容) | 解码损失计算范围 | 语义含义 | +| ------ | ---------------------------------------------------- | ------------------ | ------------------------------------ | +| 通道 1 | floor(0) + wall(1) | {1}(仅墙壁) | 空间骨架(地形结构) | +| 通道 2 | floor(0) + wall(1) + door(2) + mob(9) + entrance(10) | {2, 9, 10} | 关卡门控(交互元素,决定路径可达性) | +| 通道 3 | 完整地图(所有 tile) | {3, 4, 5, 6, 7, 8} | 收集资源(奖励与道具) | + +**通道划分的关键设计原则**: + +- 通道 1 仅包含 floor 和 wall,编码器集中学习地图的空间骨架结构; +- 通道 2 的输入**保留墙壁与空地作为空间上下文**,在此基础上叠加 door/mob/entrance,确保编码器能感知功能性 tile 在地图中的位置关系,而非将其视为无空间依附的孤立散点; +- 通道 3 的输入为**完整地图**,编码器在包含骨架与关卡结构的完整上下文中学习资源 tile 的空间分布; +- 预训练时每个解码头的损失**仅在本通道专属 tile 的位置计算,低等级 tile(floor、wall 及前级通道的 tile)损失权重为 0**——迫使编码器必须通过正确预测本通道 tile 来降低损失,无法靠拟合高频背景来规避优化压力。 + +--- + +## 整体架构 + +### 预训练阶段 + +``` +完整地图 [B, H*W] + │ + ├──► 切片 1:floor(0)+wall(1)(这已是全部内容,无需替换) + │ │ + │ ▼ + │ Encoder_1 → VQ_1 → z_1 [B, L_1, d_z] + │ │ + │ ▼ + │ DecodeHead_1 → logits [B, H*W, C] + │ │ + │ ▼ + │ Loss_1:仅在 tile∈{1} 的位置计算 Focal Loss(仅墙壁,空地权重为 0) + │ + ├──► 切片 2:保留 floor(0)+wall(1)+door(2)+mob(9)+entry(10),其余→floor(0) + │ │ + │ ▼ + │ Encoder_2 → VQ_2 → z_2 [B, L_2, d_z] + │ │ + │ ▼ + │ DecodeHead_2 → logits [B, H*W, C] + │ │ + │ ▼ + │ Loss_2:仅在 tile∈{2,9,10} 的位置计算 Focal Loss(floor/wall 权重为 0) + │ + └──► 切片 3:完整地图(所有 tile,无需替换) + │ + ▼ + Encoder_3 → VQ_3 → z_3 [B, L_3, d_z] + │ + ▼ + DecodeHead_3 → logits [B, H*W, C] + │ + ▼ + Loss_3:仅在 tile∈{3,4,5,6,7,8} 的位置计算 Focal Loss(其余 tile 权重为 0) +``` + +三路编码器**相互独立预训练**,每路的预训练损失: + +$$\mathcal{L}_{pretrain}^{(k)} = \mathcal{L}_{FL}^{(k)} + \beta \cdot \mathcal{L}_{commit}^{(k)} + \gamma \cdot \mathcal{L}_{uniform}^{(k)}$$ + +其中 $\mathcal{L}_{FL}^{(k)}$ 为通道 $k$ 的通道专属掩码 Focal Loss(见下节)。 + +### 联合训练阶段 + +``` +完整地图 ──► 三路切片 ──► [Enc_1, Enc_2, Enc_3] ──► [z_1, z_2, z_3] + │ + z = Concat([z_1, z_2, z_3], dim=1) + │ + ▼ +掩码地图 + z ──► MaskGIT (Cross-Attention) ──► 预测 logits ──► Focal Loss +``` + +联合训练总损失(不含预训练解码头): + +$$\mathcal{L}_{joint} = \mathcal{L}_{FL}^{MaskGIT} + \sum_{k=1}^{3} \left( \beta \cdot \mathcal{L}_{commit}^{(k)} + \gamma \cdot \mathcal{L}_{uniform}^{(k)} \right)$$ + +--- + +## 通道专属掩码 Focal Loss + +这是方案的核心机制。对于通道 $k$,设其专属 tile 集合为 $\mathcal{T}_k$(不含低等级 tile),则损失计算为: + +$$\mathcal{L}_{FL}^{(k)} = \frac{\sum_{i=1}^{H \times W} \mathbf{1}[y_i \in \mathcal{T}_k] \cdot \text{FL}(\hat{y}_i, y_i)}{\sum_{i=1}^{H \times W} \mathbf{1}[y_i \in \mathcal{T}_k] + \epsilon}$$ + +其中 $y_i$ 为真实 tile 类型,$\hat{y}_i$ 为解码头输出的 logits,$\text{FL}$ 为 Focal Loss。 + +**实现方式**(PyTorch 伪代码): + +```python +# 通道 2 示例(输入切片已包含 floor+wall 作为上下文) +CHANNEL_2_TILES = {2, 9, 10} # door, mob, entrance + +# target: 完整地图 ground truth,[B, H*W] +# logits: DecodeHead_2 输出,[B, H*W, num_classes] + +mask = torch.zeros_like(target, dtype=torch.bool) +for t in CHANNEL_2_TILES: + mask |= (target == t) # [B, H*W] bool,仅通道 2 专属 tile 的位置为 True + +# focal_loss: reduction='none',返回 [B * H*W] +fl = focal_loss( + logits.view(-1, num_classes), + target.view(-1), +) # [B * H*W] +fl = fl.view(B, -1) # [B, H*W] + +loss_ch2 = (fl * mask).sum() / (mask.sum() + 1e-6) +``` + +**为什么输入包含墙壁、但损失不计算墙壁**:通道 2 的切片中保留了 floor+wall,是为了给编码器提供空间结构上下文,使门/怪/入口的位置有意义(否则孤立散点难以形成有效表示)。但损失仅在 `{2, 9, 10}` 位置计算,确保梯度信号完全来自这三类 tile——编码器如果只靠拟合高频的墙壁/空地,解码头在 `{2, 9, 10}` 位置的损失无法降低,从而被迫学习功能性 tile 的空间分布。 + +--- + +## 切片构造规则 + +| 通道 | 切片中保留的 tile | 其余位置替换为 | 解码损失计算的位置(专属 tile 集合 $\mathcal{T}_k$) | +| ---- | ----------------- | -------------- | ------------------------------------------------------- | +| 1 | 0, 1 | —(无需替换) | {1}(仅墙壁;空地是墙壁的补集,能预测墙壁即能区分空地) | +| 2 | 0, 1, 2, 9, 10 | 0(floor) | {2, 9, 10}(floor/wall 损失权重为 0) | +| 3 | 全部(完整地图) | —(无需替换) | {3, 4, 5, 6, 7, 8}(其余 tile 损失权重为 0) | + +--- + +## 编码器架构设计 + +### 三路独立编码器 + +三个编码器均复用现有的 `GinkaVQVAE` 类,配置略有差异: + +| 参数 | Encoder_1(结构骨架) | Encoder_2(关卡门控) | Encoder_3(收集资源) | +| --------------- | --------------------- | --------------------- | --------------------- | +| `L`(码字数) | 2 | 2 | 2 | +| `K`(码本大小) | 16 | 16 | 16 | +| `d_z` | 64 | 64 | 64 | +| `d_model` | 128 | 64 | 64 | +| `num_layers` | 2 | 2 | 2 | + +- Encoder_1 处理高频 tile,适当加大 `d_model`; +- Encoder_2 和 Encoder_3 的功能性 tile 稀疏,信息量较小,可使用较小的 `d_model`; +- 三路 `d_z` 保持一致,以便拼接后维度齐整; +- 总参数量估算:Encoder_1 ~400K + Encoder_2 ~150K + Encoder_3 ~150K ≈ **700K**,在 1M 预算内。 + +> 考虑到训练集数量较少,可以考虑适当降低 K 和 L 的值,避免模型死记硬背,也可以防止训练集没有覆盖全部所有情况。1M 的参数量仅做估计,可以先尝试较大的参数量,如出现过拟合再降低。 + +### 联合训练时的 z 拼接 + +$$z = \text{Concat}([z_1, z_2, z_3], \dim=1) \in \mathbb{R}^{B \times (L_1+L_2+L_3) \times d_z}$$ + +以各通道 `L=2` 为例,总 memory 长度为 6,与当前 MaskGIT Cross-Attention 的 memory 规模(原来 L=2)相比略有增加,但绝对长度仍很小,不影响计算效率。 + +--- + +## 预训练流程 + +### 解码头复用 + +预训练时三路各自使用一个 `VQDecodeHead` 实例(现有类,`num_classes=16`),预训练结束后整体丢弃。解码头参数不迁移到联合训练阶段。 + +### 训练脚本 + +新增 `ginka/train_pretrain_split.py`(独立于现有的 `train_pretrain.py`): + +```python +# 伪代码结构 +for epoch in ...: + for batch in dataloader: + raw_map = batch["raw_map"] # [B, H*W] 完整地图 + + # ─── 通道 1 ─── + slice1 = make_slice(raw_map, keep={0, 1}) # floor+wall,切片即完整输入 + z_q1, z_e1, _, vq_loss1, *_ = enc1(slice1) + logits1 = head1(z_q1) + fl1 = masked_focal(logits1, raw_map, tile_set={1}) # 仅对 wall 计损失 + loss1 = fl1 + vq_loss1 + + # ─── 通道 2 ─── + slice2 = make_slice(raw_map, keep={0, 1, 2, 9, 10}) # 保留 wall/floor 作为上下文 + z_q2, z_e2, _, vq_loss2, *_ = enc2(slice2) + logits2 = head2(z_q2) + fl2 = masked_focal(logits2, raw_map, tile_set={2, 9, 10}) # 仅对专属 tile 计损失 + loss2 = fl2 + vq_loss2 + + # ─── 通道 3 ─── + slice3 = raw_map # 完整地图,无需切片 + z_q3, z_e3, _, vq_loss3, *_ = enc3(slice3) + logits3 = head3(z_q3) + fl3 = masked_focal(logits3, raw_map, tile_set={3, 4, 5, 6, 7, 8}) # 仅对专属 tile 计损失 + loss3 = fl3 + vq_loss3 + + total = loss1 + loss2 + loss3 + total.backward() + optimizer.step() +``` + +三路编码器可以**同步训练**(同一 optimizer),也可以分别独立训练——独立训练更灵活,可以对收敛速度差异大的通道单独调参。 + +### 预训练监控指标 + +| 指标 | 说明 | +| ----------------------- | ---------------------------------------------------------------- | +| 通道 1 wall 位置准确率 | Encoder_1 能否正确重建墙壁分布 | +| 通道 2 功能 tile 召回率 | Encoder_2 对 door/mob/entrance 各类的召回,应 > 50%(稀疏 tile) | +| 通道 3 资源 tile 召回率 | Encoder_3 对各资源类的召回 | +| codebook 使用熵(各路) | 各通道 codebook 是否均匀使用,避免 collapse | + +由于通道 2/3 的 tile 在每张地图中数量极少(典型地图中 door/mob/entrance 合计约 10–20 格,资源合计约 10–15 格),召回率指标比准确率更有意义。 + +--- + +## 联合训练流程 + +### 三阶段训练 + +| 阶段 | 模型状态 | 目标 | 建议轮数 | +| -------------------- | ----------------------------------------- | ------------------------------------------------ | ------------ | +| 阶段 0:分通道预训练 | 三路 Encoder + 三路 DecodeHead | 各通道 Focal Loss 收敛,功能 tile 召回率达到目标 | 30–60 epoch | +| 阶段 1:冻结热身 | 三路 Encoder 冻结 + MaskGIT 全参训练 | MaskGIT 适应三路 z 的联合分布 | 20–40 epoch | +| 阶段 2:完整联合训练 | 全部参数解冻,Encoder 使用较小 LR(×0.1) | 端到端联合优化 | 正常训练轮数 | + +### 联合训练数据集 + +`GinkaJointDataset` 需扩展为同时提供三路切片: + +```python +# 返回字典新增字段 +{ + "raw_map": ..., # [H*W] 完整地图(VQ 编码器输入) + "slice1": ..., # [H*W] 通道 1 切片 + "slice2": ..., # [H*W] 通道 2 切片 + "slice3": ..., # [H*W] 通道 3 切片 + "masked_map": ..., # [H*W] MaskGIT 输入(掩码后地图) + "target_map": ..., # [H*W] MaskGIT CE ground truth +} +``` + +--- + +## 推理时的 z 采样 + +推理时三路编码器均独立采样,无需用户输入: + +```python +# 完全随机生成 +z1 = enc1.sample(B, device) # [B, L1, d_z] +z2 = enc2.sample(B, device) # [B, L2, d_z] +z3 = enc3.sample(B, device) # [B, L3, d_z] +z = torch.cat([z1, z2, z3], dim=1) # [B, L1+L2+L3, d_z] +``` + +**分通道条件控制**(可选扩展): + +| 场景 | 通道 1 z 来源 | 通道 2 z 来源 | 通道 3 z 来源 | +| ------------ | ------------- | ------------- | ------------- | +| 完全随机生成 | 随机采样 | 随机采样 | 随机采样 | +| 指定墙壁布局 | 用户地图编码 | 随机采样 | 随机采样 | +| 指定关卡结构 | 随机采样 | 参考图编码 | 随机采样 | +| 风格迁移 | 参考图编码 | 参考图编码 | 随机采样 | + +--- + +## 与现有方案的关系 + +| 方案 | 与本方案的关系 | +| ------ | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| 方案 A | 兼容。由于通道 2/3 的编码器输入包含墙壁上下文(而非孤立散点),其 z 具备稳定的空间语义,一致性约束可同样作用于三个通道:通道 1 约束骨架结构,通道 2 约束关卡元素分布,通道 3 约束资源分布 | +| 方案 D | 本方案**替代并强化**方案 D 的预训练思路:方案 D 是全图等权 Focal Loss 预训练,本方案通过累积式输入 + 通道专属损失从根本上解决了类别不均衡问题 | +| 方案 C | 兼容。多阶段生成(先墙后门后资源)可以将通道 1/2/3 的 z 分别作为各阶段的生成条件 | + +--- + +## 超参数建议 + +| 参数 | 建议初始值 | 备注 | +| -------------------------- | ---------- | ---------------------------------------- | +| 各通道 `L`(码字数) | 2 | 三路合计 6,可视效果适当扩大 | +| 各通道 `K`(码本大小) | 16 | 通道 2/3 可减小到 8(tile 种类少) | +| `d_z` | 64 | 三路保持一致,便于拼接 | +| `β`(commit loss 权重) | 0.25 | 同现有配置 | +| `γ`(uniform loss 权重) | 0.1 | 通道 2/3 码本小,可适当增大到 0.2 | +| 预训练 epoch | 30–60 | 以功能 tile 召回率达标为准,不以轮数为限 | +| 联合训练 Encoder LR 缩放比 | 0.1 | 阶段 2 解冻后使用较小 LR 微调 | +| z dropout 概率(联合训练) | 0.1–0.2 | 三路 z 各自独立 dropout | + +--- + +## 实施步骤 + +- [ ] 在 `ginka/dataset.py` 中实现 `make_slice(map, keep_set)` 辅助函数,生成三路切片 +- [ ] 扩展 `GinkaJointDataset.__getitem__`,新增 `slice1/slice2/slice3` 字段 +- [ ] 在 `ginka/vqvae/model.py` 中确认 `GinkaVQVAE` 可独立实例化三次(无全局状态) +- [ ] 实现 `masked_focal(logits, target, tile_set)` 工具函数(`ginka/utils.py`) +- [ ] 新增 `ginka/train_pretrain_split.py` 预训练脚本(支持三路同步或分路训练) +- [ ] 修改 `ginka/train_vq.py`(联合训练脚本),支持加载三路编码器权重并拼接 z +- [ ] 修改 `GinkaMaskGIT.forward()` 以接受 `[B, L1+L2+L3, d_z]` 的拼接 z(Cross-Attention memory) +- [ ] 添加联合训练监控:各通道 codebook 使用熵、功能 tile 召回率 diff --git a/ginka/dataset.py b/ginka/dataset.py index d5fe027..ca8788f 100644 --- a/ginka/dataset.py +++ b/ginka/dataset.py @@ -459,8 +459,12 @@ class GinkaVQDataset(Dataset): struct_cond = torch.LongTensor([cond_sym, cond_room, cond_branch, cond_outer]) + raw_t = torch.LongTensor(raw_flat) return { - "raw_map": torch.LongTensor(raw_flat), # VQ-VAE 编码器输入 + "raw_map": raw_t, # VQ-VAE 编码器输入 + "slice1": make_slice(raw_t, {0, 1}), # 通道 1:floor+wall + "slice2": make_slice(raw_t, {0, 1, 2, 9, 10}),# 通道 2:floor+wall+门+怪+入口 + "slice3": raw_t.clone(), # 通道 3:完整地图 "masked_map": torch.LongTensor(masked_np), # MaskGIT 输入 "target_map": torch.LongTensor(raw_flat.copy()), # CE loss ground truth "subset": subset, # 调试/统计用 @@ -468,6 +472,78 @@ class GinkaVQDataset(Dataset): } +# --------------------------------------------------------------------------- +# make_slice:按保留集合切割地图,其余位置替换为 floor(0) +# --------------------------------------------------------------------------- + +def make_slice(map_flat: torch.Tensor, keep_set: set) -> torch.Tensor: + """ + 从完整地图中只保留 keep_set 中的 tile 类型,其余位置替换为 floor(0)。 + + Args: + map_flat: LongTensor [H*W] 完整地图 tile ID 序列 + keep_set: set of int 需要保留的 tile 类型集合 + + Returns: + LongTensor [H*W] 切片后的地图(非保留 tile 位置值为 0) + """ + out = map_flat.clone() + mask = torch.zeros_like(out, dtype=torch.bool) + for t in keep_set: + mask |= (out == t) + out[~mask] = 0 + return out + + +# --------------------------------------------------------------------------- +# GinkaSplitDataset:三通道分拆预训练专用数据集 +# --------------------------------------------------------------------------- + +class GinkaSplitDataset(Dataset): + """ + 三通道分拆预训练(方案 B)专用数据集。 + + 每个样本只提供完整地图及其三路切片,不做 MaskGIT 掩码处理。 + 切片按累积式设计: + slice1 = floor(0) + wall(1) + slice2 = floor(0) + wall(1) + door(2) + mob(9) + entrance(10) + slice3 = 完整地图(所有 tile) + + 返回 dict: + raw_map: LongTensor [H*W] 完整原始地图 + slice1: LongTensor [H*W] 通道 1 切片(floor+wall) + slice2: LongTensor [H*W] 通道 2 切片(floor+wall+门+怪+入口) + slice3: LongTensor [H*W] 通道 3 切片(完整地图) + """ + + def __init__(self, data_path: str): + self.data = load_data(data_path) + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + item = self.data[idx] + arr = np.array(item['map'], dtype=np.int64) # [H, W] + + # 随机旋转 / 翻转数据增强 + if np.random.rand() > 0.5: + k = np.random.randint(1, 4) + arr = np.rot90(arr, k).copy() + if np.random.rand() > 0.5: + arr = np.fliplr(arr).copy() + if np.random.rand() > 0.5: + arr = np.flipud(arr).copy() + + raw = torch.LongTensor(arr.reshape(-1)) # [H*W] + return { + "raw_map": raw, + "slice1": make_slice(raw, {0, 1}), + "slice2": make_slice(raw, {0, 1, 2, 9, 10}), + "slice3": raw.clone(), + } + + if __name__ == "__main__": import os data_path = os.path.join(os.path.dirname(__file__), '..', 'data', 'ginka-dataset.json') diff --git a/ginka/maskGIT/model.py b/ginka/maskGIT/model.py index 3d8f2f8..7b26755 100644 --- a/ginka/maskGIT/model.py +++ b/ginka/maskGIT/model.py @@ -83,12 +83,13 @@ class GinkaMaskGIT(nn.Module): ) -> torch.Tensor: """ Args: - map: [B, H*W] 掩码后的地图 token 序列(MASK token = 15) - z: [B, L, d_z] VQ-VAE 量化后的离散隐变量 - struct_cond: [B, 4] 结构标签 LongTensor,顺序为 - [cond_sym, cond_room, cond_branch, cond_outer]; - 为 None 时等价于全 null(无条件模式) - dropout_struct: bool 强制将所有结构标签替换为 null(推理时无条件生成) + map: [B, H*W] 掩码后的地图 token 序列(MASK token = 15) + z: [B, L_total, d_z] VQ-VAE 量化后的离散隐变量; + 方案 B 中 L_total = L1+L2+L3(三路 z 拼接) + struct_cond: [B, 4] 结构标签 LongTensor,顺序为 + [cond_sym, cond_room, cond_branch, cond_outer]; + 为 None 时等价于全 null(无条件模式) + dropout_struct: bool 强制将所有结构标签替换为 null(推理时无条件生成) Returns: logits: [B, H*W, num_classes] diff --git a/ginka/train_pretrain_split.py b/ginka/train_pretrain_split.py new file mode 100644 index 0000000..9cad109 --- /dev/null +++ b/ginka/train_pretrain_split.py @@ -0,0 +1,363 @@ +""" +三通道分拆预训练脚本(方案 B) + +三路编码器各自负责一个语义通道: + 通道 1:空间骨架(floor+wall),损失仅计算 wall(1) 位置 + 通道 2:关卡门控(floor+wall+door+mob+entrance),损失仅计算 {2,9,10} 位置 + 通道 3:收集资源(完整地图),损失仅计算 {3,4,5,6,7,8} 位置 + +预训练完成后保存各通道编码器权重(不含解码头), +供联合训练脚本 train_vq.py 加载并拼接 z。 + +用法示例: + python -m ginka.train_pretrain_split + python -m ginka.train_pretrain_split --resume True --state result/pretrain_split/split-10.pth + # 预训练完成后指定权重路径启动联合训练: + python -m ginka.train_vq --pretrain_split result/pretrain_split/split_final.pth +""" + +import argparse +import os +import sys +from datetime import datetime + +import numpy as np +import torch +import torch.optim as optim +from torch.utils.data import DataLoader +from tqdm import tqdm + +from .vqvae.model import GinkaVQVAE, VQDecodeHead +from .dataset import GinkaSplitDataset +from .utils import masked_focal + +# --------------------------------------------------------------------------- +# 超参数 +# --------------------------------------------------------------------------- +BATCH_SIZE = 64 +NUM_CLASSES = 16 +MAP_SIZE = 13 * 13 +FOCAL_GAMMA = 2.0 + +# 通道 1:空间骨架(floor+wall) +CH1_KEEP = {0, 1} # 编码器输入保留的 tile +CH1_LOSS = {1} # 损失计算范围(仅 wall) +CH1_D_MODEL = 128 +CH1_NHEAD = 4 + +# 通道 2:关卡门控 +CH2_KEEP = {0, 1, 2, 9, 10} +CH2_LOSS = {2, 9, 10} +CH2_D_MODEL = 64 +CH2_NHEAD = 4 + +# 通道 3:收集资源 +CH3_KEEP = None # 完整地图,无需切片 +CH3_LOSS = {3, 4, 5, 6, 7, 8} +CH3_D_MODEL = 64 +CH3_NHEAD = 4 + +# 三路共用的 VQ 超参 +VQ_L = 2 +VQ_K = 16 +VQ_D_Z = 64 +VQ_LAYERS = 2 +VQ_DIM_FF = 256 +VQ_BETA = 0.25 # commit loss 权重 +VQ_GAMMA = 0.1 # entropy loss 权重 + +# 解码头超参(三路共用相同规格) +DH_NHEAD = 4 +DH_DIM_FF = 256 +DH_LAYERS = 2 + +# --------------------------------------------------------------------------- +# 设备 +# --------------------------------------------------------------------------- +device = torch.device( + "cuda:1" if torch.cuda.is_available() + else "mps" if torch.backends.mps.is_available() + else "cpu" +) + +os.makedirs("result/pretrain_split", exist_ok=True) + +disable_tqdm = not sys.stdout.isatty() + +# --------------------------------------------------------------------------- +# 参数解析 +# --------------------------------------------------------------------------- +def parse_arguments(): + parser = argparse.ArgumentParser(description="三通道分拆 VQ 编码器预训练(方案 B)") + parser.add_argument("--resume", type=bool, default=False) + parser.add_argument("--state", type=str, default="result/pretrain_split/split-10.pth", + help="续训时加载的检查点路径") + parser.add_argument("--train", type=str, default="ginka-dataset.json") + parser.add_argument("--validate", type=str, default="ginka-eval.json") + parser.add_argument("--epochs", type=int, default=60) + parser.add_argument("--checkpoint", type=int, default=5, + help="每隔多少 epoch 保存检查点并输出验证指标") + parser.add_argument("--load_optim", type=bool, default=True) + return parser.parse_args() + +# --------------------------------------------------------------------------- +# 验证:各通道专属 tile 召回率 + codebook 使用熵 +# --------------------------------------------------------------------------- +@torch.no_grad() +def validate( + enc1, enc2, enc3, + head1, head2, head3, + dataloader_val: DataLoader, +) -> dict: + for m in [enc1, enc2, enc3, head1, head2, head3]: + m.eval() + + # 每类 tile 的 tp / gt 计数 + ch1_tp, ch1_gt = 0, 0 # wall(1) + ch2_tp = {t: 0 for t in CH2_LOSS} # {2,9,10} + ch2_gt = {t: 0 for t in CH2_LOSS} + ch3_tp = {t: 0 for t in CH3_LOSS} # {3,4,5,6,7,8} + ch3_gt = {t: 0 for t in CH3_LOSS} + + # codebook 使用频次(用于熵估算) + codebook_counts = [ + torch.zeros(VQ_K, dtype=torch.long), # 通道 1 + torch.zeros(VQ_K, dtype=torch.long), # 通道 2 + torch.zeros(VQ_K, dtype=torch.long), # 通道 3 + ] + + for batch in tqdm(dataloader_val, desc="Validating", leave=False, disable=disable_tqdm): + raw_map = batch["raw_map"].to(device) + s1 = batch["slice1"].to(device) + s2 = batch["slice2"].to(device) + s3 = batch["slice3"].to(device) + + # 通道 1 + z_q1, _, idx1, _, _, _ = enc1(s1) + logits1 = head1(z_q1) + pred1 = logits1.argmax(dim=-1) # [B, H*W] + wall_m = (raw_map == 1) + ch1_tp += (pred1[wall_m] == 1).sum().item() + ch1_gt += wall_m.sum().item() + for code in idx1.view(-1).cpu(): + codebook_counts[0][code] += 1 + + # 通道 2 + z_q2, _, idx2, _, _, _ = enc2(s2) + logits2 = head2(z_q2) + pred2 = logits2.argmax(dim=-1) + for t in CH2_LOSS: + m = (raw_map == t) + ch2_tp[t] += (pred2[m] == t).sum().item() + ch2_gt[t] += m.sum().item() + for code in idx2.view(-1).cpu(): + codebook_counts[1][code] += 1 + + # 通道 3 + z_q3, _, idx3, _, _, _ = enc3(s3) + logits3 = head3(z_q3) + pred3 = logits3.argmax(dim=-1) + for t in CH3_LOSS: + m = (raw_map == t) + ch3_tp[t] += (pred3[m] == t).sum().item() + ch3_gt[t] += m.sum().item() + for code in idx3.view(-1).cpu(): + codebook_counts[2][code] += 1 + + def _entropy(counts): + """估算 codebook 使用熵(bits)。""" + counts = counts.float() + 1e-8 + p = counts / counts.sum() + return float(-(p * torch.log2(p)).sum()) + + metrics = { + "ch1_wall_recall": ch1_tp / max(ch1_gt, 1), + "ch2_recall": {t: ch2_tp[t] / max(ch2_gt[t], 1) for t in CH2_LOSS}, + "ch3_recall": {t: ch3_tp[t] / max(ch3_gt[t], 1) for t in CH3_LOSS}, + "codebook_entropy": [_entropy(c) for c in codebook_counts], + } + return metrics + +# --------------------------------------------------------------------------- +# 主训练函数 +# --------------------------------------------------------------------------- +def train(): + print(f"Using device: {device}") + args = parse_arguments() + + # ---- 三路编码器 ---- + enc1 = GinkaVQVAE( + num_classes=NUM_CLASSES, L=VQ_L, K=VQ_K, d_z=VQ_D_Z, + d_model=CH1_D_MODEL, nhead=CH1_NHEAD, num_layers=VQ_LAYERS, + dim_ff=VQ_DIM_FF, beta=VQ_BETA, gamma=VQ_GAMMA, + ).to(device) + + enc2 = GinkaVQVAE( + num_classes=NUM_CLASSES, L=VQ_L, K=VQ_K, d_z=VQ_D_Z, + d_model=CH2_D_MODEL, nhead=CH2_NHEAD, num_layers=VQ_LAYERS, + dim_ff=VQ_DIM_FF, beta=VQ_BETA, gamma=VQ_GAMMA, + ).to(device) + + enc3 = GinkaVQVAE( + num_classes=NUM_CLASSES, L=VQ_L, K=VQ_K, d_z=VQ_D_Z, + d_model=CH3_D_MODEL, nhead=CH3_NHEAD, num_layers=VQ_LAYERS, + dim_ff=VQ_DIM_FF, beta=VQ_BETA, gamma=VQ_GAMMA, + ).to(device) + + # ---- 三路解码头(预训练专用,训练后丢弃)---- + head1 = VQDecodeHead( + num_classes=NUM_CLASSES, d_z=VQ_D_Z, map_size=MAP_SIZE, + nhead=DH_NHEAD, dim_ff=DH_DIM_FF, num_layers=DH_LAYERS, + ).to(device) + + head2 = VQDecodeHead( + num_classes=NUM_CLASSES, d_z=VQ_D_Z, map_size=MAP_SIZE, + nhead=DH_NHEAD, dim_ff=DH_DIM_FF, num_layers=DH_LAYERS, + ).to(device) + + head3 = VQDecodeHead( + num_classes=NUM_CLASSES, d_z=VQ_D_Z, map_size=MAP_SIZE, + nhead=DH_NHEAD, dim_ff=DH_DIM_FF, num_layers=DH_LAYERS, + ).to(device) + + # ---- 优化器(三路同步训练) ---- + optimizer = optim.AdamW( + list(enc1.parameters()) + list(enc2.parameters()) + list(enc3.parameters()) + + list(head1.parameters()) + list(head2.parameters()) + list(head3.parameters()), + lr=1e-3, + weight_decay=1e-4, + ) + + start_epoch = 0 + + # ---- 续训 ---- + if args.resume: + ckpt = torch.load(args.state, map_location=device) + enc1.load_state_dict(ckpt["enc1"]) + enc2.load_state_dict(ckpt["enc2"]) + enc3.load_state_dict(ckpt["enc3"]) + head1.load_state_dict(ckpt["head1"]) + head2.load_state_dict(ckpt["head2"]) + head3.load_state_dict(ckpt["head3"]) + if args.load_optim and "optimizer" in ckpt: + optimizer.load_state_dict(ckpt["optimizer"]) + start_epoch = ckpt.get("epoch", 0) + print(f"Resumed from epoch {start_epoch}: {args.state}") + + # ---- 数据集 ---- + ds_train = GinkaSplitDataset(args.train) + ds_val = GinkaSplitDataset(args.validate) + dl_train = DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True) + dl_val = DataLoader(ds_val, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True) + + print(f"训练集大小: {len(ds_train)},验证集大小: {len(ds_val)}") + + total_params = ( + sum(p.numel() for p in enc1.parameters()) + + sum(p.numel() for p in enc2.parameters()) + + sum(p.numel() for p in enc3.parameters()) + ) + print(f"编码器总参数量(三路): {total_params:,} ({total_params / 1e6:.3f}M)") + + # ---- 训练循环 ---- + for epoch in range(start_epoch, args.epochs): + for m in [enc1, enc2, enc3, head1, head2, head3]: + m.train() + + total_loss = 0.0 + ch_losses = [0.0, 0.0, 0.0] + + for batch in tqdm(dl_train, desc=f"Epoch {epoch + 1}/{args.epochs}", disable=disable_tqdm): + raw_map = batch["raw_map"].to(device) + s1 = batch["slice1"].to(device) + s2 = batch["slice2"].to(device) + s3 = batch["slice3"].to(device) + + optimizer.zero_grad() + + # ─── 通道 1 ─── + z_q1, _, _, vq_loss1, commit_loss1, entropy_loss1 = enc1(s1) + logits1 = head1(z_q1) # [B, H*W, C] + fl1 = masked_focal(logits1, raw_map, CH1_LOSS, gamma=FOCAL_GAMMA) + loss1 = fl1 + VQ_BETA * commit_loss1 + VQ_GAMMA * entropy_loss1 + + # ─── 通道 2 ─── + z_q2, _, _, vq_loss2, commit_loss2, entropy_loss2 = enc2(s2) + logits2 = head2(z_q2) + fl2 = masked_focal(logits2, raw_map, CH2_LOSS, gamma=FOCAL_GAMMA) + loss2 = fl2 + VQ_BETA * commit_loss2 + VQ_GAMMA * entropy_loss2 + + # ─── 通道 3 ─── + z_q3, _, _, vq_loss3, commit_loss3, entropy_loss3 = enc3(s3) + logits3 = head3(z_q3) + fl3 = masked_focal(logits3, raw_map, CH3_LOSS, gamma=FOCAL_GAMMA) + loss3 = fl3 + VQ_BETA * commit_loss3 + VQ_GAMMA * entropy_loss3 + + loss = loss1 + loss2 + loss3 + loss.backward() + torch.nn.utils.clip_grad_norm_( + list(enc1.parameters()) + list(enc2.parameters()) + list(enc3.parameters()) + + list(head1.parameters()) + list(head2.parameters()) + list(head3.parameters()), + max_norm=1.0, + ) + optimizer.step() + + total_loss += loss.item() + ch_losses[0] += loss1.item() + ch_losses[1] += loss2.item() + ch_losses[2] += loss3.item() + + n_batches = len(dl_train) + print( + f"[{epoch + 1:03d}] total={total_loss / n_batches:.4f} " + f"ch1={ch_losses[0] / n_batches:.4f} " + f"ch2={ch_losses[1] / n_batches:.4f} " + f"ch3={ch_losses[2] / n_batches:.4f}" + ) + + # ---- 检查点 & 验证 ---- + if (epoch + 1) % args.checkpoint == 0 or epoch + 1 == args.epochs: + metrics = validate(enc1, enc2, enc3, head1, head2, head3, dl_val) + print( + f" 验证 ch1_wall_recall={metrics['ch1_wall_recall']:.3f} " + f"ch2_recall={metrics['ch2_recall']} " + f"ch3_recall={metrics['ch3_recall']}" + ) + print( + f" codebook_entropy ch1={metrics['codebook_entropy'][0]:.3f} " + f"ch2={metrics['codebook_entropy'][1]:.3f} " + f"ch3={metrics['codebook_entropy'][2]:.3f}" + ) + + ts = datetime.now().strftime("%m%d-%H%M") + ckpt_path = f"result/pretrain_split/split-{epoch + 1}.pth" + torch.save({ + "epoch": epoch + 1, + "enc1": enc1.state_dict(), + "enc2": enc2.state_dict(), + "enc3": enc3.state_dict(), + "head1": head1.state_dict(), + "head2": head2.state_dict(), + "head3": head3.state_dict(), + "optimizer": optimizer.state_dict(), + "metrics": metrics, + "ts": ts, + }, ckpt_path) + print(f" Saved checkpoint: {ckpt_path}") + + # ---- 保存最终编码器权重(供联合训练加载) ---- + final_path = "result/pretrain_split/split_final.pth" + torch.save({ + "epoch": args.epochs, + "enc1": enc1.state_dict(), + "enc2": enc2.state_dict(), + "enc3": enc3.state_dict(), + # 解码头不迁移,不保存 + }, final_path) + print(f"\n预训练完成,编码器权重已保存至: {final_path}") + print("接下来运行联合训练(阶段 1 冻结热身):") + print(f" python -m ginka.train_vq --pretrain_split {final_path} --freeze_vq True") + + +if __name__ == "__main__": + train() diff --git a/ginka/train_vq.py b/ginka/train_vq.py index 3071252..82cfc9b 100644 --- a/ginka/train_vq.py +++ b/ginka/train_vq.py @@ -45,16 +45,24 @@ MAP_H = MAP_W = 13 FOCAL_GAMMA = 2.0 # focal loss 聚焦参数(越大越关注难例/稀有类别) WALL_MASK_RATIO = 0.8 -# VQ-VAE 超参 -VQ_L = 2 # summary token 数量(即 z 的序列长度) -VQ_K = 8 # codebook 大小 -VQ_D_Z = 128 # codebook 嵌入维度 -VQ_D_MODEL= 192 -VQ_NHEAD = 8 -VQ_LAYERS = 4 -VQ_DIM_FF = 512 -VQ_BETA = 0.5 # commit loss 权重 -VQ_GAMMA = 0.0 # entropy loss 权重 +# VQ-VAE 公共超参(三路编码器共用,方案 B 三通道分拆) +VQ_L = 2 # 每路码字序列长度(三路合计 L1+L2+L3 = 6) +VQ_K = 16 # codebook 大小 +VQ_D_Z = 64 # codebook 嵌入维度(三路保持一致,便于拼接) +VQ_BETA = 0.25 # commit loss 权重 +VQ_GAMMA = 0.1 # entropy loss 权重 + +# 各通道编码器配置 +CH1_D_MODEL = 128; CH1_NHEAD = 4 # 通道 1:空间骨架(floor+wall) +CH2_D_MODEL = 64; CH2_NHEAD = 4 # 通道 2:关卡门控 +CH3_D_MODEL = 64; CH3_NHEAD = 4 # 通道 3:收集资源 +VQ_LAYERS = 2 +VQ_DIM_FF = 256 + +# 通道专属损失计算范围(用于监控验证召回率) +CH1_LOSS = {1} +CH2_LOSS = {2, 9, 10} +CH3_LOSS = {3, 4, 5, 6, 7, 8} # MaskGIT 超参 MG_D_MODEL = 256 @@ -102,9 +110,12 @@ def parse_arguments(): parser.add_argument("--checkpoint", type=int, default=5, help="每隔多少 epoch 保存检查点并验证") parser.add_argument("--load_optim", type=bool, default=True) - parser.add_argument("--freeze_vq", type=bool, default=False, - help="(方案 D 阶段 1)冻结 VQ 编码器,仅训练 MaskGIT。" + parser.add_argument("--freeze_vq", type=bool, default=False, + help="(方案 B 阶段 1)冻结三路 VQ 编码器,仅训练 MaskGIT。" "适用于预训练权重加载后的热身阶段。") + parser.add_argument("--pretrain_split", type=str, default="", + help="(方案 B)三通道分拆预训练检查点路径;" + "指定后将从该检查点加载三路编码器初始权重。") return parser.parse_args() # --------------------------------------------------------------------------- @@ -352,7 +363,9 @@ def make_random_struct_cond() -> torch.Tensor: @torch.no_grad() def validate( - model_vq: GinkaVQVAE, + enc1: GinkaVQVAE, + enc2: GinkaVQVAE, + enc3: GinkaVQVAE, model_mg: GinkaMaskGIT, dataloader_val: DataLoader, tile_dict: dict, @@ -373,7 +386,8 @@ def validate( 场景5 (scene5_random) : 无数据集参照,随机稀疏墙壁种子 → 完全随机生成 列: random seed | z_rand×(N+1) """ - model_vq.eval() + for enc in [enc1, enc2, enc3]: + enc.eval() model_mg.eval() # 按 epoch 建立独立子文件夹,保留每次验证结果方便回溯 @@ -385,14 +399,33 @@ def validate( captured = {s: None for s in ('A', 'B', 'C', 'D')} # ── 计算 val loss + 捕获各子集样本 ────────────────────────────────────── + def _encode_three(s1, s2, s3): + """三路编码并拼接 z_q。""" + z_q1, _, _, vq1, _, _ = enc1(s1) + z_q2, _, _, vq2, _, _ = enc2(s2) + z_q3, _, _, vq3, _, _ = enc3(s3) + z_q = torch.cat([z_q1, z_q2, z_q3], dim=1) # [B, L1+L2+L3, d_z] + vq_loss = vq1 + vq2 + vq3 + return z_q, vq_loss + + def _sample_three(B_size): + """三路随机采样并拼接 z。""" + z1 = enc1.sample(B_size, device) + z2 = enc2.sample(B_size, device) + z3 = enc3.sample(B_size, device) + return torch.cat([z1, z2, z3], dim=1) + for batch in tqdm(dataloader_val, desc="Validating", leave=False, disable=disable_tqdm): raw_map = batch["raw_map"].to(device) # [B, 169] masked_map = batch["masked_map"].to(device) # [B, 169] target_map = batch["target_map"].to(device) # [B, 169] + s1 = batch["slice1"].to(device) + s2 = batch["slice2"].to(device) + s3 = batch["slice3"].to(device) subsets = batch["subset"] # list of str B = raw_map.shape[0] - z_q, _, _, vq_loss, _, _ = model_vq(raw_map) + z_q, vq_loss = _encode_three(s1, s2, s3) struct_cond_b = batch["struct_cond"].to(device) # [B, 4] logits = model_mg(masked_map, z_q, struct_cond=struct_cond_b) mask = (masked_map == MASK_TOKEN) @@ -419,7 +452,7 @@ def validate( def _rand_gens(cond_map, n): imgs = [] for i in range(n): - z_r = model_vq.sample(1, device) + z_r = _sample_three(1) gen = maskgit_generate(model_mg, z_r, init_map=cond_map) # struct_cond=None 无条件 imgs.append(label_image(make_map_image(gen[0], tile_dict), f"z_rand_{i + 1}")) return imgs @@ -428,7 +461,7 @@ def validate( def _rand_gens_with_struct(cond_map, n): imgs = [] for i in range(n): - z_r = model_vq.sample(1, device) + z_r = _sample_three(1) sc_r = make_random_struct_cond() # [1, 4] 随机合法标签 gen = maskgit_generate(model_mg, z_r, init_map=cond_map, struct_cond=sc_r) img = label_image(make_map_image(gen[0], tile_dict), f"z_rand_{i + 1}") @@ -539,15 +572,15 @@ def train(): print(f"Using device: {device}") args = parse_arguments() - # ---- 模型 ---- - model_vq = GinkaVQVAE( - num_classes=NUM_CLASSES, - L=VQ_L, K=VQ_K, d_z=VQ_D_Z, - d_model=VQ_D_MODEL, nhead=VQ_NHEAD, - num_layers=VQ_LAYERS, dim_ff=VQ_DIM_FF, - map_size=MAP_SIZE, + # ---- 三路编码器(方案 B 三通道分拆) ---- + _vq_common = dict( + num_classes=NUM_CLASSES, L=VQ_L, K=VQ_K, d_z=VQ_D_Z, + num_layers=VQ_LAYERS, dim_ff=VQ_DIM_FF, map_size=MAP_SIZE, beta=VQ_BETA, gamma=VQ_GAMMA, - ).to(device) + ) + enc1 = GinkaVQVAE(d_model=CH1_D_MODEL, nhead=CH1_NHEAD, **_vq_common).to(device) + enc2 = GinkaVQVAE(d_model=CH2_D_MODEL, nhead=CH2_NHEAD, **_vq_common).to(device) + enc3 = GinkaVQVAE(d_model=CH3_D_MODEL, nhead=CH3_NHEAD, **_vq_common).to(device) model_mg = GinkaMaskGIT( num_classes=NUM_CLASSES, @@ -559,11 +592,11 @@ def train(): struct_dropout=MG_STRUCT_DROPOUT, ).to(device) - vq_params = sum(p.numel() for p in model_vq.parameters()) + enc_params = sum(p.numel() for m in [enc1, enc2, enc3] for p in m.parameters()) mg_params = sum(p.numel() for p in model_mg.parameters()) - print(f"VQ-VAE 参数量: {vq_params:,} ({vq_params/1e6:.3f}M)") - print(f"MaskGIT 参数量: {mg_params:,} ({mg_params/1e6:.3f}M)") - print(f"Total 参数量: {vq_params+mg_params:,} ({(vq_params+mg_params)/1e6:.3f}M)") + print(f"Encoders 参数量(三路): {enc_params:,} ({enc_params/1e6:.3f}M)") + print(f"MaskGIT 参数量: {mg_params:,} ({mg_params/1e6:.3f}M)") + print(f"Total 参数量: {enc_params+mg_params:,} ({(enc_params+mg_params)/1e6:.3f}M)") # ---- 数据集 ---- dataset_train = GinkaVQDataset( @@ -587,19 +620,29 @@ def train(): num_workers=0, ) - # ---- 优化器(联合训练,两个模型共用一个 optimizer)---- - all_params = list(model_vq.parameters()) + list(model_mg.parameters()) + # ---- 优化器(联合训练,三路编码器 + MaskGIT 共用)---- + enc_params_list = list(enc1.parameters()) + list(enc2.parameters()) + list(enc3.parameters()) + all_params = enc_params_list + list(model_mg.parameters()) optimizer = optim.AdamW(all_params, lr=2e-4, weight_decay=1e-2) scheduler = optim.lr_scheduler.CosineAnnealingLR( optimizer, T_max=args.epochs, eta_min=1e-6 ) - # ---- 续训 ---- + # ---- 权重加载 ---- start_epoch = 0 - if args.resume: + if args.pretrain_split: + # 从分拆预训练检查点加载三路编码器初始权重(阶段 1 冻结热身前) + ckpt = torch.load(args.pretrain_split, map_location=device) + enc1.load_state_dict(ckpt["enc1"]) + enc2.load_state_dict(ckpt["enc2"]) + enc3.load_state_dict(ckpt["enc3"]) + print(f"已加载分拆预训练编码器权重: {args.pretrain_split}") + elif args.resume: ckpt = torch.load(args.state, map_location=device) - model_vq.load_state_dict(ckpt["vq_state"], strict=False) - model_mg.load_state_dict(ckpt["mg_state"], strict=False) + enc1.load_state_dict(ckpt["enc1"], strict=False) + enc2.load_state_dict(ckpt["enc2"], strict=False) + enc3.load_state_dict(ckpt["enc3"], strict=False) + model_mg.load_state_dict(ckpt["mg_state"], strict=False) if args.load_optim and ckpt.get("optim_state") is not None: optimizer.load_state_dict(ckpt["optim_state"]) start_epoch = ckpt.get("epoch", 0) @@ -613,16 +656,18 @@ def train(): if img is not None: tile_dict[name] = img - # ---- 方案 D 阶段 1:冻结 VQ 编码器 ---- + # ---- 方案 B 阶段 1:冻结三路 VQ 编码器 ---- if args.freeze_vq: - for p in model_vq.parameters(): - p.requires_grad_(False) - print("VQ 编码器已冻结(方案 D 阶段 1:MaskGIT 热身)。") + for enc in [enc1, enc2, enc3]: + for p in enc.parameters(): + p.requires_grad_(False) + print("三路 VQ 编码器已冻结(阶段 1:MaskGIT 热身)。") # ---- 训练循环 ---- for epoch in tqdm(range(start_epoch, start_epoch + args.epochs), desc="Joint Training", disable=disable_tqdm): - model_vq.train() + for enc in [enc1, enc2, enc3]: + enc.train() model_mg.train() loss_total = 0.0 @@ -638,13 +683,23 @@ def train(): raw_map = batch["raw_map"].to(device) # [B, 169] masked_map = batch["masked_map"].to(device) # [B, 169] target_map = batch["target_map"].to(device) # [B, 169] + s1 = batch["slice1"].to(device) # 通道 1 切片 + s2 = batch["slice2"].to(device) # 通道 2 切片 + s3 = batch["slice3"].to(device) # 通道 3 切片 for s in batch["subset"]: subset_stats[s] = subset_stats.get(s, 0) + 1 # ---- 前向传播 ---- - # 1. VQ-VAE 编码真实地图 → z_q, z_e - z_q, z_e, _, vq_loss, commit_loss, entropy_loss = model_vq(raw_map) # z_q/z_e: [B, L, d_z] + # 1. 三路 VQ 编码器各自编码对应切片 → 拼接 z + z_q1, z_e1, _, vq_loss1, commit_loss1, entropy_loss1 = enc1(s1) + z_q2, z_e2, _, vq_loss2, commit_loss2, entropy_loss2 = enc2(s2) + z_q3, z_e3, _, vq_loss3, commit_loss3, entropy_loss3 = enc3(s3) + z_q = torch.cat([z_q1, z_q2, z_q3], dim=1) # [B, L1+L2+L3, d_z] + z_e = torch.cat([z_e1, z_e2, z_e3], dim=1) # [B, L1+L2+L3, d_z] + vq_loss = vq_loss1 + vq_loss2 + vq_loss3 + commit_loss = commit_loss1 + commit_loss2 + commit_loss3 + entropy_loss = entropy_loss1 + entropy_loss2 + entropy_loss3 # 2. MaskGIT 以掩码地图 + z + 结构标签预测原始 tile struct_cond = batch["struct_cond"].to(device) # [B, 4] @@ -655,23 +710,25 @@ def train(): ce_loss = focal_loss(logits.permute(0, 2, 1), target_map) masked_ce = (ce_loss * mask).sum() / (mask.sum() + 1e-6) - # 4. z 一致性约束(方案 A):将 MaskGIT 的 logits 经温度平滑后 - # 与 VQ 编码器的 tile embedding 做加权求和,得到软嵌入序列, - # 再送入编码器得到 z_pred_e,约束其与真实 z_e 对齐。 - # 梯度从 z_pred_e 回传到 MaskGIT 的 logits; - # VQ 参数在此路径上临时冻结(requires_grad=False), - # 确保编码器权重仅由真实地图路径(vq_loss)更新,不被一致性损失带偏。 - for p in model_vq.parameters(): - p.requires_grad_(False) + # 4. z 一致性约束(方案 A 扩展到三通道): + # MaskGIT logits 经温度平滑后与各编码器的 tile embedding 做加权求和, + # 得到软嵌入 → 各编码器再次编码 → z_pred_e_k 与真实 z_e_k 对齐。 + # 编码器权重在此路径上临时冻结,确保梯度仅回传至 MaskGIT。 + for enc in [enc1, enc2, enc3]: + for p in enc.parameters(): + p.requires_grad_(False) - soft_probs = F.softmax(logits / CONSIST_TEMP, dim=-1) # [B, H*W, V] - tile_emb = model_vq.tile_embedding.weight # [V, d_model] - soft_emb = soft_probs @ tile_emb # [B, H*W, d_model] - z_pred_e = model_vq.encode_soft(soft_emb) # [B, L, d_z] + soft_probs = F.softmax(logits / CONSIST_TEMP, dim=-1) # [B, H*W, V] + z_pred_e1 = enc1.encode_soft(soft_probs @ enc1.tile_embedding.weight) + z_pred_e2 = enc2.encode_soft(soft_probs @ enc2.tile_embedding.weight) + z_pred_e3 = enc3.encode_soft(soft_probs @ enc3.tile_embedding.weight) + z_pred_e = torch.cat([z_pred_e1, z_pred_e2, z_pred_e3], dim=1) consist_loss = F.mse_loss(z_pred_e, z_e.detach()) - - for p in model_vq.parameters(): - p.requires_grad_(True) + + if not args.freeze_vq: + for enc in [enc1, enc2, enc3]: + for p in enc.parameters(): + p.requires_grad_(True) # 5. 联合损失 loss = masked_ce + vq_loss + CONSIST_LAMBDA * consist_loss @@ -709,26 +766,31 @@ def train(): ckpt_path = f"result/joint/joint-{epoch + 1}.pth" torch.save({ "epoch": epoch + 1, - "vq_state": model_vq.state_dict(), + "enc1": enc1.state_dict(), + "enc2": enc2.state_dict(), + "enc3": enc3.state_dict(), "mg_state": model_mg.state_dict(), "optim_state":optimizer.state_dict(), }, ckpt_path) tqdm.write(f" 检查点已保存: {ckpt_path}") val_loss = validate( - model_vq, model_mg, dataloader_val, tile_dict, epoch + 1 + enc1, enc2, enc3, model_mg, dataloader_val, tile_dict, epoch + 1 ) tqdm.write( f"[Validate] Epoch {epoch + 1:4d} | Val Loss {val_loss:.5f}" ) # 恢复训练模式 - model_vq.train() + for enc in [enc1, enc2, enc3]: + enc.train() model_mg.train() print("训练结束。") torch.save({ "epoch": start_epoch + args.epochs, - "vq_state": model_vq.state_dict(), + "enc1": enc1.state_dict(), + "enc2": enc2.state_dict(), + "enc3": enc3.state_dict(), "mg_state": model_mg.state_dict(), }, "result/joint/joint_final.pth") diff --git a/ginka/utils.py b/ginka/utils.py index 29895b5..b604f9e 100644 --- a/ginka/utils.py +++ b/ginka/utils.py @@ -1,4 +1,5 @@ import torch +import torch.nn.functional as F import numpy as np def print_memory(device, tag=""): @@ -31,4 +32,46 @@ def nms_sampling(noise: np.ndarray, k: int, radius=2): result[y, x] = 1 return result - \ No newline at end of file + + +def masked_focal( + logits: torch.Tensor, + target: torch.Tensor, + tile_set: set, + gamma: float = 2.0, + eps: float = 1e-6, +) -> torch.Tensor: + """ + 通道专属掩码 Focal Loss:仅在 tile_set 中指定的 tile 位置计算损失。 + + Args: + logits: [B, H*W, num_classes] 解码头输出(未经 softmax) + target: [B, H*W] 完整地图 ground truth(整数 tile ID) + tile_set: set of int 本通道专属 tile 集合,其余位置损失权重为 0 + gamma: Focal Loss 聚焦参数 + eps: 数值稳定的分母偏置 + + Returns: + scalar tensor 通道专属掩码 Focal Loss + """ + B, S, C = logits.shape + + # 构造掩码:仅在专属 tile 位置为 True + mask = torch.zeros(B, S, dtype=torch.bool, device=logits.device) + for t in tile_set: + mask |= (target == t) + + if not mask.any(): + return logits.sum() * 0.0 # 保留计算图,返回零梯度 + + # Focal Loss(reduction='none') + ce = F.cross_entropy( + logits.view(-1, C), + target.view(-1), + reduction='none', + ).view(B, S) # [B, S] + + pt = torch.exp(-ce.detach()) # 正确类预测概率,stop-gradient + fl = (1.0 - pt) ** gamma * ce + + return (fl * mask).sum() / (mask.sum() + eps) diff --git a/train_full.sh b/train_full.sh index 6c6135e..5de90a2 100644 --- a/train_full.sh +++ b/train_full.sh @@ -1,10 +1,14 @@ #!/usr/bin/env bash # ============================================================================== -# 三阶段完整训练流水线 +# 三阶段完整训练流水线(方案 B:三通道分拆 VQ 编码器) # -# 阶段 0 VQ 编码器预训练 train_pretrain.py -# 阶段 1 MaskGIT 热身 train_vq.py --freeze_vq True +# 阶段 0 三通道分拆预训练 train_pretrain_split.py +# enc1(floor+wall) / enc2(+door+mob+entrance) / enc3(全图) +# 各自仅对本通道 tile 计算 masked Focal Loss +# 阶段 1 MaskGIT 热身 train_vq.py --pretrain_split --freeze_vq True +# 三路编码器权重冻结,仅训练 MaskGIT # 阶段 2 完整联合训练 train_vq.py +# 三路编码器 + MaskGIT 全量优化 # # 用法: # bash train_full.sh # 从头开始三阶段训练 @@ -19,10 +23,10 @@ set -euo pipefail TRAIN_DATA="ginka-dataset.json" EVAL_DATA="ginka-eval.json" -# 阶段 0:预训练 +# 阶段 0:三通道分拆预训练 P0_EPOCHS=100 P0_CHECKPOINT=10 -P0_FINAL="result/pretrain/pretrain_final.pth" +P0_FINAL="result/pretrain_split/split_final.pth" # 阶段 1:冻结编码器热身 P1_EPOCHS=50 @@ -68,14 +72,15 @@ die() { } # ------------------------------------------------------------------------------ -# 阶段 0:VQ 编码器预训练 +# 阶段 0:三通道分拆 VQ 编码器预训练 # ------------------------------------------------------------------------------ if [[ $START_PHASE -le 0 ]]; then - log "阶段 0 / 3 VQ 编码器预训练 (epochs=${P0_EPOCHS})" - python3 -u -m ginka.train_pretrain \ - --train "$TRAIN_DATA" \ - --validate "$EVAL_DATA" \ - --epochs "$P0_EPOCHS" \ + log "阶段 0 / 3 三通道分拆 VQ 预训练 (epochs=${P0_EPOCHS})" + mkdir -p result/pretrain_split + python3 -u -m ginka.train_pretrain_split \ + --train "$TRAIN_DATA" \ + --validate "$EVAL_DATA" \ + --epochs "$P0_EPOCHS" \ --checkpoint "$P0_CHECKPOINT" [[ -f "$P0_FINAL" ]] || die "阶段 0 未生成预期检查点:$P0_FINAL" @@ -86,24 +91,23 @@ else fi # ------------------------------------------------------------------------------ -# 阶段 1:MaskGIT 热身(VQ 编码器冻结) +# 阶段 1:MaskGIT 热身(三路 VQ 编码器冻结) # ------------------------------------------------------------------------------ if [[ $START_PHASE -le 1 ]]; then - log "阶段 1 / 3 MaskGIT 热身(VQ 冻结) (epochs=${P1_EPOCHS})" + log "阶段 1 / 3 MaskGIT 热身(三路 VQ 冻结) (epochs=${P1_EPOCHS})" + mkdir -p result/joint python3 -u -m ginka.train_vq \ - --train "$TRAIN_DATA" \ - --validate "$EVAL_DATA" \ - --resume True \ - --state "$P0_FINAL" \ - --load_optim False \ - --freeze_vq True \ - --epochs "$P1_EPOCHS" \ - --checkpoint "$P1_CHECKPOINT" + --train "$TRAIN_DATA" \ + --validate "$EVAL_DATA" \ + --pretrain_split "$P0_FINAL" \ + --load_optim False \ + --freeze_vq True \ + --epochs "$P1_EPOCHS" \ + --checkpoint "$P1_CHECKPOINT" - # 阶段 1 最后一个检查点 + # 取阶段 1 最后一个检查点,固定为阶段 2 入口 _P1_LAST=$(ls -t result/joint/joint-*.pth 2>/dev/null | head -1) [[ -n "$_P1_LAST" ]] || die "阶段 1 未生成任何检查点(result/joint/joint-*.pth)" - # 复制为阶段 1 固定终态,供阶段 2 加载 cp "$_P1_LAST" "$P1_FINAL" log "阶段 1 完成 → $P1_FINAL(来自 $_P1_LAST)" else @@ -112,18 +116,18 @@ else fi # ------------------------------------------------------------------------------ -# 阶段 2:完整联合训练 +# 阶段 2:完整联合训练(三路编码器 + MaskGIT 全量) # ------------------------------------------------------------------------------ if [[ $START_PHASE -le 2 ]]; then log "阶段 2 / 3 完整联合训练 (epochs=${P2_EPOCHS})" python3 -u -m ginka.train_vq \ - --train "$TRAIN_DATA" \ - --validate "$EVAL_DATA" \ - --resume True \ - --state "$P1_FINAL" \ - --load_optim False \ - --freeze_vq False \ - --epochs "$P2_EPOCHS" \ + --train "$TRAIN_DATA" \ + --validate "$EVAL_DATA" \ + --resume True \ + --state "$P1_FINAL" \ + --load_optim False \ + --freeze_vq False \ + --epochs "$P2_EPOCHS" \ --checkpoint "$P2_CHECKPOINT" log "阶段 2 完成"