mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-21 01:54:51 +08:00
feat: split the VQ encoder into three parts
Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
parent
dcb85dc1f8
commit
3874d4dd95
322
docs/vqvae-split-channel-design.md
Normal file
@ -0,0 +1,322 @@
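# Design Document: Three-Channel Split Pretraining for the VQ Encoder

## Background and Problem Diagnosis

### Core Problem

The current VQ-VAE + MaskGIT joint training scheme (as well as Scheme A, the closed-loop constraint, and Scheme D, full-map pretraining) leaves one fundamental class-imbalance problem unsolved:

Roughly 70–85% of the cells in a map are wall (tile=1) or floor (tile=0), while functional tiles such as monsters (9), doors (2), entrances (10), gems (4/5/6), keys (3), potions (7), and items (8) make up only a small minority. When the encoder takes the full map as input and the reconstruction loss weights every position equally:

- The encoder's gradient signal is dominated by wall/floor, so the feature vectors mainly encode spatial structure rather than functional content;
- At decoding (or generation) time, recall for monster/door/resource tiles is far below that for walls;
- Even though Scheme D pretrains on the full map, its loss is still a Focal Loss weighted equally over all positions, so the pretraining stage itself reinforces this bias;
- **The loss has already been switched from CE to Focal Loss, yet recall has not improved noticeably**, which indicates that the root cause is the proportion of positions the gradient signal comes from, not the form of the loss function;
- When z itself is of insufficient quality, Scheme A's closed-loop consistency constraint can only reinforce the existing bias and cannot fix it at the root.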
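
### Improvement Goal

Split the map **into three independent channels by semantic level**, using a design of **cumulative inputs + channel-exclusive loss**: each channel's encoder input contains the tiles of the current channel plus all lower-level tiles (which provide spatial context), while the decode head's loss is **computed only at positions holding the channel's own tiles, with lower-level tiles given a loss weight of 0**. This lets the encoder learn within a meaningful spatial structure (so functional tiles do not degrade into isolated scattered points) and forces the decoder to predict the channel's own tiles correctly in order to lower the loss; it cannot drive the loss down by fitting the high-frequency wall/floor tiles.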
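
---

## Tile Type Partition

Following the definitions in `data/src/shared.ts`, the 15 valid tile types (excluding MASK=15) are grouped into three channels by game semantics:

| Channel   | Encoder input (slice content)                        | Decode loss scope  | Semantic meaning                                                      |
| --------- | ---------------------------------------------------- | ------------------ | --------------------------------------------------------------------- |
| Channel 1 | floor(0) + wall(1)                                   | {1} (walls only)   | Spatial skeleton (terrain structure)                                  |
| Channel 2 | floor(0) + wall(1) + door(2) + mob(9) + entrance(10) | {2, 9, 10}         | Level gating (interactive elements that determine path reachability)  |
| Channel 3 | Full map (all tiles)                                 | {3, 4, 5, 6, 7, 8} | Collectible resources (rewards and items)                             |

**Key design principles of the channel partition**:

- Channel 1 contains only floor and wall, so its encoder concentrates on the spatial skeleton of the map;
- Channel 2's input **keeps wall and floor as spatial context** and overlays door/mob/entrance on top, so the encoder can perceive where functional tiles sit in the map instead of treating them as isolated points with no spatial anchoring;
- Channel 3's input is the **full map**, so the encoder learns the spatial distribution of resource tiles within the complete context of skeleton and level structure;
- During pretraining, each decode head's loss is **computed only at positions holding the channel's own tiles; lower-level tiles (floor, wall, and the tiles of earlier channels) get a loss weight of 0**. The encoder therefore has to predict the channel's own tiles correctly to lower the loss and cannot dodge the optimization pressure by fitting the high-frequency background.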
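
To make the effect of the partition concrete, the sketch below counts how many cells of a map would contribute to each channel's loss. The tile sets are copied from the table above; the 4x4 `toy_map` and the helper name `loss_position_counts` are hypothetical and only serve as illustration.

```python
# Channel-exclusive tile sets, copied from the partition table.
CHANNEL_TILES = {1: {1}, 2: {2, 9, 10}, 3: {3, 4, 5, 6, 7, 8}}


def loss_position_counts(flat_map):
    """Count, per channel, how many cells would contribute to that channel's loss."""
    return {
        ch: sum(1 for tile in flat_map if tile in tiles)
        for ch, tiles in CHANNEL_TILES.items()
    }


# Hypothetical 4x4 toy map (flattened row by row), not taken from the dataset.
toy_map = [
    1, 1, 1, 1,
    1, 0, 9, 1,
    1, 3, 2, 1,
    1, 1, 10, 1,
]
counts = loss_position_counts(toy_map)
print(counts)
```

Even on this tiny map most cells are walls, while channels 2 and 3 each see only a handful of positions. Normalizing each channel's loss by its own position count keeps those few positions from being averaged away.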

---

## Overall Architecture

### Pretraining Stage

```
Full map [B, H*W]
    │
    ├──► Slice 1: keep floor(0)+wall(1), everything else → floor(0)
    │         │
    │         ▼
    │     Encoder_1 → VQ_1 → z_1  [B, L_1, d_z]
    │         │
    │         ▼
    │     DecodeHead_1 → logits [B, H*W, C]
    │         │
    │         ▼
    │     Loss_1: Focal Loss only at positions where tile∈{1} (walls only; floor has weight 0)
    │
    ├──► Slice 2: keep floor(0)+wall(1)+door(2)+mob(9)+entrance(10), everything else → floor(0)
    │         │
    │         ▼
    │     Encoder_2 → VQ_2 → z_2  [B, L_2, d_z]
    │         │
    │         ▼
    │     DecodeHead_2 → logits [B, H*W, C]
    │         │
    │         ▼
    │     Loss_2: Focal Loss only at positions where tile∈{2,9,10} (floor/wall have weight 0)
    │
    └──► Slice 3: full map (all tiles, nothing replaced)
              │
              ▼
          Encoder_3 → VQ_3 → z_3  [B, L_3, d_z]
              │
              ▼
          DecodeHead_3 → logits [B, H*W, C]
              │
              ▼
          Loss_3: Focal Loss only at positions where tile∈{3,4,5,6,7,8} (all other tiles have weight 0)
```

The three encoders are **pretrained independently of one another**. The pretraining loss for each channel is:

$$\mathcal{L}_{pretrain}^{(k)} = \mathcal{L}_{FL}^{(k)} + \beta \cdot \mathcal{L}_{commit}^{(k)} + \gamma \cdot \mathcal{L}_{uniform}^{(k)}$$

where $\mathcal{L}_{FL}^{(k)}$ is the channel-exclusive masked Focal Loss of channel $k$ (see the masked Focal Loss section below).

### Joint Training Stage

```
Full map ──► three slices ──► [Enc_1, Enc_2, Enc_3] ──► [z_1, z_2, z_3]
                                                              │
                                            z = Concat([z_1, z_2, z_3], dim=1)
                                                              │
                                                              ▼
Masked map + z ──► MaskGIT (Cross-Attention) ──► predicted logits ──► Focal Loss
```

Total joint training loss (the pretraining decode heads are not included):

$$\mathcal{L}_{joint} = \mathcal{L}_{FL}^{MaskGIT} + \sum_{k=1}^{3} \left( \beta \cdot \mathcal{L}_{commit}^{(k)} + \gamma \cdot \mathcal{L}_{uniform}^{(k)} \right)$$
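
---

## Channel-Exclusive Masked Focal Loss

This is the core mechanism of the scheme. For channel $k$, let $\mathcal{T}_k$ be its exclusive tile set (lower-level tiles excluded). The loss is then:

$$\mathcal{L}_{FL}^{(k)} = \frac{\sum_{i=1}^{H \times W} \mathbf{1}[y_i \in \mathcal{T}_k] \cdot \text{FL}(\hat{y}_i, y_i)}{\sum_{i=1}^{H \times W} \mathbf{1}[y_i \in \mathcal{T}_k] + \epsilon}$$

where $y_i$ is the ground-truth tile type, $\hat{y}_i$ is the logits output by the decode head, and $\text{FL}$ is the Focal Loss.

**Implementation** (PyTorch sketch):

```python
import torch
import torch.nn.functional as F

# Channel 2 example (the input slice already contains floor+wall as context)
CHANNEL_2_TILES = (2, 9, 10)  # door, mob, entrance
GAMMA = 2.0                   # focal-loss focusing parameter


def channel2_loss(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """logits: DecodeHead_2 output [B, H*W, num_classes]; target: full-map ground truth [B, H*W]."""
    # [B, H*W] bool, True only at positions holding a channel-2 tile
    mask = torch.zeros_like(target, dtype=torch.bool)
    for t in CHANNEL_2_TILES:
        mask |= (target == t)

    # Per-position focal loss with no reduction, shape [B, H*W].
    # Standard (1 - p_t)^gamma * CE form; the project's own focal_loss may differ.
    ce = F.cross_entropy(logits.transpose(1, 2), target, reduction="none")
    fl = (1.0 - torch.exp(-ce)) ** GAMMA * ce

    # Average over channel-2 positions only; epsilon guards maps that contain none
    return (fl * mask).sum() / (mask.sum() + 1e-6)
```

**Why the input contains walls but the loss ignores them**: channel 2's slice keeps floor+wall to give the encoder spatial-structure context, so that the positions of doors, monsters, and entrances are meaningful (isolated points are hard to turn into a useful representation). The loss, however, is computed only at `{2, 9, 10}` positions, so the gradient signal comes entirely from these three tile types. If the encoder only fits the high-frequency wall/floor tiles, the decode head's loss at `{2, 9, 10}` positions cannot go down, so the encoder is forced to learn the spatial distribution of the functional tiles.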
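
A dependency-free reference version can be handy for unit-testing a tensor implementation of `masked_focal`. The sketch below follows the formula above with plain Python lists. The function names and toy logits are hypothetical, and it assumes the standard focal form $(1-p_t)^\gamma \cdot (-\log p_t)$ with no class weighting.

```python
import math


def softmax(logits):
    """Numerically stable softmax over one position's class logits."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def masked_focal_reference(logits, target, tile_set, gamma=2.0, eps=1e-6):
    """Mean focal loss over positions whose true tile is in tile_set.

    logits: list of per-position class-logit lists; target: list of tile IDs.
    Positions outside tile_set add nothing to the numerator or the denominator.
    """
    total, count = 0.0, 0
    for pos_logits, tile in zip(logits, target):
        if tile not in tile_set:
            continue  # loss weight 0 for background and other-channel tiles
        p_true = softmax(pos_logits)[tile]
        total += -((1.0 - p_true) ** gamma) * math.log(p_true)
        count += 1
    return total / (count + eps)


# Toy check with 3 classes: changing the logits at background positions
# (tile 0) leaves the loss for tile_set={2} unchanged.
target = [0, 2, 0]
logits_a = [[5.0, 0.0, 0.0], [0.0, 0.0, 1.0], [5.0, 0.0, 0.0]]
logits_b = [[-5.0, 3.0, 0.0], [0.0, 0.0, 1.0], [0.0, 9.0, 0.0]]
loss_a = masked_focal_reference(logits_a, target, tile_set={2})
loss_b = masked_focal_reference(logits_b, target, tile_set={2})
```

Here `loss_a` and `loss_b` are identical because only the middle position enters the sum. This is the property that stops the decoder from lowering the loss through wall/floor predictions.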

---

## Slice Construction Rules

| Channel | Tiles kept in the slice | Other positions replaced with | Positions where the decode loss is computed (exclusive tile set $\mathcal{T}_k$) |
| ------- | ----------------------- | ----------------------------- | -------------------------------------------------------------------------------- |
| 1       | 0, 1                    | 0 (floor)                     | {1} (walls only; floor is the complement of wall, so predicting walls is enough to tell floor apart) |
| 2       | 0, 1, 2, 9, 10          | 0 (floor)                     | {2, 9, 10} (floor/wall have loss weight 0)                                       |
| 3       | all (full map)          | none (nothing replaced)       | {3, 4, 5, 6, 7, 8} (all other tiles have loss weight 0)                          |
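
---

## Encoder Architecture Design

### Three Independent Encoders

All three encoders reuse the existing `GinkaVQVAE` class, with slightly different configurations:

| Parameter             | Encoder_1 (structural skeleton) | Encoder_2 (level gating) | Encoder_3 (collectible resources) |
| --------------------- | ------------------------------- | ------------------------ | --------------------------------- |
| `L` (number of codes) | 2                               | 2                        | 2                                 |
| `K` (codebook size)   | 16                              | 16                       | 16                                |
| `d_z`                 | 64                              | 64                       | 64                                |
| `d_model`             | 128                             | 64                       | 64                                |
| `num_layers`          | 2                               | 2                        | 2                                 |

- Encoder_1 handles the high-frequency tiles, so its `d_model` is somewhat larger;
- The functional tiles seen by Encoder_2 and Encoder_3 are sparse and carry less information, so a smaller `d_model` is sufficient;
- `d_z` is the same for all three so that the concatenated z has a uniform dimension;
- Estimated total parameter count: Encoder_1 ~400K + Encoder_2 ~150K + Encoder_3 ~150K ≈ **700K**, within the 1M budget.

> Because the training set is small, consider lowering K and L somewhat so the model cannot simply memorize the data, which also guards against the training set not covering every case. The 1M parameter figure is only an estimate; it is fine to start with a larger model and shrink it if overfitting appears.

### Concatenating z During Joint Training

$$z = \text{Concat}([z_1, z_2, z_3], \text{dim}=1) \in \mathbb{R}^{B \times (L_1+L_2+L_3) \times d_z}$$

With `L=2` per channel, the total memory length is 6. This is slightly longer than the current MaskGIT Cross-Attention memory (previously L=2), but the absolute length is still tiny and has no noticeable impact on compute cost.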

---

## Pretraining Procedure

### Decode Head Reuse

During pretraining each channel uses its own `VQDecodeHead` instance (an existing class, `num_classes=16`). The heads are discarded entirely once pretraining ends; their parameters are not carried over to joint training.

### Training Script

Add `ginka/train_pretrain_split.py` (separate from the existing `train_pretrain.py`):

```python
# Pseudocode structure
for epoch in ...:
    for batch in dataloader:
        raw_map = batch["raw_map"]  # [B, H*W] full map
        optimizer.zero_grad()

        # ─── Channel 1 ───
        slice1 = make_slice(raw_map, {0, 1})  # floor+wall; the slice is the whole encoder input
        z_q1, z_e1, _, vq_loss1, *_ = enc1(slice1)
        logits1 = head1(z_q1)
        fl1 = masked_focal(logits1, raw_map, tile_set={1})  # loss on wall only
        loss1 = fl1 + vq_loss1

        # ─── Channel 2 ───
        slice2 = make_slice(raw_map, {0, 1, 2, 9, 10})  # keep wall/floor as context
        z_q2, z_e2, _, vq_loss2, *_ = enc2(slice2)
        logits2 = head2(z_q2)
        fl2 = masked_focal(logits2, raw_map, tile_set={2, 9, 10})  # loss on exclusive tiles only
        loss2 = fl2 + vq_loss2

        # ─── Channel 3 ───
        slice3 = raw_map  # full map, no slicing needed
        z_q3, z_e3, _, vq_loss3, *_ = enc3(slice3)
        logits3 = head3(z_q3)
        fl3 = masked_focal(logits3, raw_map, tile_set={3, 4, 5, 6, 7, 8})  # loss on exclusive tiles only
        loss3 = fl3 + vq_loss3

        total = loss1 + loss2 + loss3
        total.backward()
        optimizer.step()
```

The three encoders can be trained **synchronously** (one shared optimizer) or separately. Separate training is more flexible because channels whose convergence speed differs a lot can be tuned individually.

### Pretraining Monitoring Metrics

| Metric                              | Description                                                                                  |
| ----------------------------------- | -------------------------------------------------------------------------------------------- |
| Channel 1 accuracy at wall positions | Whether Encoder_1 reconstructs the wall layout correctly                                    |
| Channel 2 functional-tile recall    | Per-class recall of Encoder_2 for door/mob/entrance; should be > 50% (sparse tiles)          |
| Channel 3 resource-tile recall      | Per-class recall of Encoder_3 for each resource type                                         |
| Codebook usage entropy (per channel) | Whether each channel's codebook is used evenly, to avoid collapse                           |

Because channel 2/3 tiles are very rare in each map (a typical map has about 10–20 door/mob/entrance cells in total and about 10–15 resource cells in total), recall is a more meaningful metric than accuracy.
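
The two kinds of metric above can be sketched without any tensor library. The snippet is a minimal illustration on hypothetical toy data; `per_tile_recall` and `codebook_entropy_bits` are illustrative names rather than functions from the repository, and the entropy here skips unused codes instead of smoothing the counts.

```python
import math
from collections import Counter


def per_tile_recall(pred, target, tile_set):
    """Recall per tile type: correct predictions / ground-truth occurrences."""
    gt = Counter(t for t in target if t in tile_set)
    tp = Counter(t for p, t in zip(pred, target) if t in tile_set and p == t)
    return {tile: tp[tile] / max(gt[tile], 1) for tile in sorted(tile_set)}


def codebook_entropy_bits(code_indices, codebook_size):
    """Shannon entropy (bits) of code usage; log2(codebook_size) means uniform use."""
    counts = Counter(code_indices)
    total = sum(counts.values())
    entropy = 0.0
    for code in range(codebook_size):
        p = counts[code] / total if total else 0.0
        if p > 0:
            entropy -= p * math.log2(p)
    return entropy


# Toy data (hypothetical): two doors, one mob, one entrance in the ground truth.
target = [1, 2, 0, 9, 2, 10, 1, 0]
pred = [1, 2, 0, 0, 0, 10, 1, 0]
recall = per_tile_recall(pred, target, tile_set={2, 9, 10})
uniform = codebook_entropy_bits(list(range(16)), codebook_size=16)
collapsed = codebook_entropy_bits([3] * 16, codebook_size=16)
```

With `K=16` the entropy ranges from 0 bits (every input maps to one code, i.e. collapse) to 4 bits (perfectly even usage), which gives a scale for reading the per-channel values logged during validation.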
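
---

## Joint Training Procedure

### Three-Stage Training

| Stage                           | Model state                                                  | Goal                                                                 | Suggested epochs        |
| ------------------------------- | ------------------------------------------------------------ | -------------------------------------------------------------------- | ----------------------- |
| Stage 0: per-channel pretraining | Three encoders + three decode heads                         | Each channel's Focal Loss converges; functional-tile recall hits target | 30–60 epochs         |
| Stage 1: frozen warm-up         | Three encoders frozen + MaskGIT trained with all parameters  | MaskGIT adapts to the joint distribution of the three z's            | 20–40 epochs            |
| Stage 2: full joint training    | All parameters unfrozen; encoders use a smaller LR (×0.1)    | End-to-end joint optimization                                        | Normal number of epochs |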
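
The stage table can be encoded as a small lookup that a training script consults when building optimizer parameter groups. This is only a sketch: `stage_config`, its dictionary keys, and the `base_lr` default are assumptions for illustration, with `None` meaning "frozen or not present in this stage".

```python
def stage_config(stage, base_lr=1e-3, encoder_lr_scale=0.1):
    """Return the learning rate per module group for a given stage.

    stage 0: split pretraining (encoders + decode heads)
    stage 1: frozen warm-up (MaskGIT only)
    stage 2: full joint training (encoders at a reduced learning rate)
    """
    if stage == 0:
        return {"encoders": base_lr, "decode_heads": base_lr, "maskgit": None}
    if stage == 1:
        return {"encoders": None, "decode_heads": None, "maskgit": base_lr}
    if stage == 2:
        return {
            "encoders": base_lr * encoder_lr_scale,
            "decode_heads": None,
            "maskgit": base_lr,
        }
    raise ValueError(f"unknown stage: {stage}")


warmup = stage_config(1)
joint = stage_config(2)
```

In a PyTorch script, groups whose value is `None` would have `requires_grad` turned off (or simply be left out of the optimizer), and the remaining groups would be passed as separate parameter groups with their own `lr`.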

### Joint Training Dataset

`GinkaJointDataset` needs to be extended to provide all three slices at once:

```python
# New fields in the returned dict
{
    "raw_map":    ...,  # [H*W] full map (VQ encoder input)
    "slice1":     ...,  # [H*W] channel 1 slice
    "slice2":     ...,  # [H*W] channel 2 slice
    "slice3":     ...,  # [H*W] channel 3 slice
    "masked_map": ...,  # [H*W] MaskGIT input (map after masking)
    "target_map": ...,  # [H*W] MaskGIT CE ground truth
}
```

---

## Sampling z at Inference Time

At inference time all three encoders sample independently, with no user input required:

```python
# Fully random generation
z1 = enc1.sample(B, device)  # [B, L1, d_z]
z2 = enc2.sample(B, device)  # [B, L2, d_z]
z3 = enc3.sample(B, device)  # [B, L3, d_z]
z = torch.cat([z1, z2, z3], dim=1)  # [B, L1+L2+L3, d_z]
```

**Per-channel conditional control** (optional extension):

| Scenario                  | Channel 1 z source    | Channel 2 z source     | Channel 3 z source |
| ------------------------- | --------------------- | ---------------------- | ------------------ |
| Fully random generation   | Random sample         | Random sample          | Random sample      |
| Specified wall layout     | Encoded from user map | Random sample          | Random sample      |
| Specified level structure | Random sample         | Encoded from reference | Random sample      |
| Style transfer            | Encoded from reference | Encoded from reference | Random sample     |
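
The scenario table amounts to choosing, per channel, between encoding a given map and sampling at random. A minimal sketch of that dispatch follows; `assemble_z` and the string stand-ins are hypothetical, and in real use the callables would wrap each encoder's forward pass and its `sample` method.

```python
def assemble_z(sources, encode_fns, sample_fns, maps=None):
    """Build the per-channel z list for one generation request.

    sources:    e.g. {1: "encode", 2: "sample", 3: "sample"}
    encode_fns: channel -> callable(map) returning that channel's z
    sample_fns: channel -> callable() returning a randomly sampled z
    maps:       channel -> reference/user map, required for "encode" channels
    """
    maps = maps or {}
    z_parts = []
    for ch in (1, 2, 3):
        if sources[ch] == "encode":
            if ch not in maps:
                raise ValueError(f"channel {ch} needs a map to encode")
            z_parts.append(encode_fns[ch](maps[ch]))
        else:
            z_parts.append(sample_fns[ch]())
    return z_parts  # concatenate along the sequence axis before calling MaskGIT


# Stand-ins that return labels instead of tensors, for illustration only.
encode_fns = {ch: (lambda m, ch=ch: f"enc{ch}({m})") for ch in (1, 2, 3)}
sample_fns = {ch: (lambda ch=ch: f"rand{ch}") for ch in (1, 2, 3)}

# "Specified wall layout": channel 1 is encoded from the user's map.
wall_layout = assemble_z(
    {1: "encode", 2: "sample", 3: "sample"},
    encode_fns, sample_fns, maps={1: "user_map"},
)
```

Keeping the source choice in a plain dictionary makes it straightforward to add scenarios beyond the four listed, since the channel order stays fixed and only the dictionary changes.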
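
---

## Relationship to Existing Schemes

| Scheme   | Relationship to this scheme |
| -------- | --------------------------- |
| Scheme A | Compatible. Because the encoder inputs of channels 2/3 include wall context (rather than isolated points), their z carries stable spatial semantics, so the consistency constraint can be applied to all three channels alike: channel 1 constrains the skeleton, channel 2 constrains the distribution of level elements, and channel 3 constrains the distribution of resources |
| Scheme D | This scheme **replaces and strengthens** Scheme D's pretraining idea: Scheme D pretrains with an equally weighted full-map Focal Loss, whereas this scheme addresses the class imbalance at its source through cumulative inputs + channel-exclusive loss |
| Scheme C | Compatible. Multi-stage generation (walls first, then doors, then resources) can use the z of channels 1/2/3 as the generation condition for the respective stages |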
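
---

## Suggested Hyperparameters

| Parameter                             | Suggested initial value | Notes                                                                    |
| ------------------------------------- | ----------------------- | ------------------------------------------------------------------------ |
| `L` per channel (number of codes)     | 2                       | 6 in total across channels; can be increased depending on results        |
| `K` per channel (codebook size)       | 16                      | Channels 2/3 can drop to 8 (fewer tile types)                            |
| `d_z`                                 | 64                      | Same for all three channels, for easy concatenation                      |
| `β` (commit loss weight)              | 0.25                    | Same as the existing configuration                                       |
| `γ` (uniform loss weight)             | 0.1                     | Channels 2/3 have small codebooks; can be raised to 0.2                  |
| Pretraining epochs                    | 30–60                   | Stop when functional-tile recall reaches the target, not at a fixed count |
| Encoder LR scale in joint training    | 0.1                     | Fine-tune with a smaller LR after unfreezing in stage 2                  |
| z dropout probability (joint training) | 0.1–0.2                | Dropout applied independently to each channel's z                        |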

---

## Implementation Steps

- [ ] Implement the `make_slice(map, keep_set)` helper in `ginka/dataset.py` to generate the three slices
- [ ] Extend `GinkaJointDataset.__getitem__` with the new `slice1/slice2/slice3` fields
- [ ] Confirm in `ginka/vqvae/model.py` that `GinkaVQVAE` can be instantiated three times independently (no global state)
- [ ] Implement the `masked_focal(logits, target, tile_set)` utility function (`ginka/utils.py`)
- [ ] Add the `ginka/train_pretrain_split.py` pretraining script (supporting synchronous or per-channel training)
- [ ] Modify `ginka/train_vq.py` (the joint training script) to load the three encoders' weights and concatenate z
- [ ] Modify `GinkaMaskGIT.forward()` to accept the concatenated z of shape `[B, L1+L2+L3, d_z]` (Cross-Attention memory)
- [ ] Add joint training monitoring: per-channel codebook usage entropy and functional-tile recall
@ -459,8 +459,12 @@ class GinkaVQDataset(Dataset):
|
||||
|
||||
struct_cond = torch.LongTensor([cond_sym, cond_room, cond_branch, cond_outer])
|
||||
|
||||
raw_t = torch.LongTensor(raw_flat)
|
||||
return {
|
||||
"raw_map": torch.LongTensor(raw_flat), # VQ-VAE 编码器输入
|
||||
"raw_map": raw_t, # VQ-VAE 编码器输入
|
||||
"slice1": make_slice(raw_t, {0, 1}), # 通道 1:floor+wall
|
||||
"slice2": make_slice(raw_t, {0, 1, 2, 9, 10}),# 通道 2:floor+wall+门+怪+入口
|
||||
"slice3": raw_t.clone(), # 通道 3:完整地图
|
||||
"masked_map": torch.LongTensor(masked_np), # MaskGIT 输入
|
||||
"target_map": torch.LongTensor(raw_flat.copy()), # CE loss ground truth
|
||||
"subset": subset, # 调试/统计用
|
||||
@ -468,6 +472,78 @@ class GinkaVQDataset(Dataset):
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# make_slice: keep only the tiles in keep_set; replace every other position with floor(0)
# ---------------------------------------------------------------------------

def make_slice(map_flat: torch.Tensor, keep_set: set) -> torch.Tensor:
    """
    Keep only the tile types in keep_set from the full map; set all other positions to floor(0).

    Args:
        map_flat: LongTensor [H*W]  full-map tile ID sequence
        keep_set: set of int        tile types to keep

    Returns:
        LongTensor [H*W]  sliced map (positions of non-kept tiles hold 0)
    """
    out = map_flat.clone()
    mask = torch.zeros_like(out, dtype=torch.bool)
    for t in keep_set:
        mask |= (out == t)
    out[~mask] = 0
    return out
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GinkaSplitDataset:三通道分拆预训练专用数据集
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class GinkaSplitDataset(Dataset):
|
||||
"""
|
||||
三通道分拆预训练(方案 B)专用数据集。
|
||||
|
||||
每个样本只提供完整地图及其三路切片,不做 MaskGIT 掩码处理。
|
||||
切片按累积式设计:
|
||||
slice1 = floor(0) + wall(1)
|
||||
slice2 = floor(0) + wall(1) + door(2) + mob(9) + entrance(10)
|
||||
slice3 = 完整地图(所有 tile)
|
||||
|
||||
返回 dict:
|
||||
raw_map: LongTensor [H*W] 完整原始地图
|
||||
slice1: LongTensor [H*W] 通道 1 切片(floor+wall)
|
||||
slice2: LongTensor [H*W] 通道 2 切片(floor+wall+门+怪+入口)
|
||||
slice3: LongTensor [H*W] 通道 3 切片(完整地图)
|
||||
"""
|
||||
|
||||
def __init__(self, data_path: str):
|
||||
self.data = load_data(data_path)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = self.data[idx]
|
||||
arr = np.array(item['map'], dtype=np.int64) # [H, W]
|
||||
|
||||
# 随机旋转 / 翻转数据增强
|
||||
if np.random.rand() > 0.5:
|
||||
k = np.random.randint(1, 4)
|
||||
arr = np.rot90(arr, k).copy()
|
||||
if np.random.rand() > 0.5:
|
||||
arr = np.fliplr(arr).copy()
|
||||
if np.random.rand() > 0.5:
|
||||
arr = np.flipud(arr).copy()
|
||||
|
||||
raw = torch.LongTensor(arr.reshape(-1)) # [H*W]
|
||||
return {
|
||||
"raw_map": raw,
|
||||
"slice1": make_slice(raw, {0, 1}),
|
||||
"slice2": make_slice(raw, {0, 1, 2, 9, 10}),
|
||||
"slice3": raw.clone(),
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
data_path = os.path.join(os.path.dirname(__file__), '..', 'data', 'ginka-dataset.json')
|
||||
|
||||
@ -83,12 +83,13 @@ class GinkaMaskGIT(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
map: [B, H*W] 掩码后的地图 token 序列(MASK token = 15)
|
||||
z: [B, L, d_z] VQ-VAE 量化后的离散隐变量
|
||||
struct_cond: [B, 4] 结构标签 LongTensor,顺序为
|
||||
[cond_sym, cond_room, cond_branch, cond_outer];
|
||||
为 None 时等价于全 null(无条件模式)
|
||||
dropout_struct: bool 强制将所有结构标签替换为 null(推理时无条件生成)
|
||||
map: [B, H*W] 掩码后的地图 token 序列(MASK token = 15)
|
||||
z: [B, L_total, d_z] VQ-VAE 量化后的离散隐变量;
|
||||
方案 B 中 L_total = L1+L2+L3(三路 z 拼接)
|
||||
struct_cond: [B, 4] 结构标签 LongTensor,顺序为
|
||||
[cond_sym, cond_room, cond_branch, cond_outer];
|
||||
为 None 时等价于全 null(无条件模式)
|
||||
dropout_struct: bool 强制将所有结构标签替换为 null(推理时无条件生成)
|
||||
|
||||
Returns:
|
||||
logits: [B, H*W, num_classes]
|
||||
|
||||
363
ginka/train_pretrain_split.py
Normal file
@ -0,0 +1,363 @@
"""
Three-channel split pretraining script (Scheme B)

Each of the three encoders handles one semantic channel:
    Channel 1: spatial skeleton (floor+wall); loss computed only at wall(1) positions
    Channel 2: level gating (floor+wall+door+mob+entrance); loss computed only at {2,9,10} positions
    Channel 3: collectible resources (full map); loss computed only at {3,4,5,6,7,8} positions

After pretraining, each channel's encoder weights are saved (without the decode heads)
for the joint training script train_vq.py to load and concatenate z.

Usage examples:
    python -m ginka.train_pretrain_split
    python -m ginka.train_pretrain_split --resume True --state result/pretrain_split/split-10.pth
    # After pretraining, start joint training with the saved weights:
    python -m ginka.train_vq --pretrain_split result/pretrain_split/split_final.pth
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from .vqvae.model import GinkaVQVAE, VQDecodeHead
|
||||
from .dataset import GinkaSplitDataset
|
||||
from .utils import masked_focal
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 超参数
|
||||
# ---------------------------------------------------------------------------
|
||||
BATCH_SIZE = 64
|
||||
NUM_CLASSES = 16
|
||||
MAP_SIZE = 13 * 13
|
||||
FOCAL_GAMMA = 2.0
|
||||
|
||||
# 通道 1:空间骨架(floor+wall)
|
||||
CH1_KEEP = {0, 1} # 编码器输入保留的 tile
|
||||
CH1_LOSS = {1} # 损失计算范围(仅 wall)
|
||||
CH1_D_MODEL = 128
|
||||
CH1_NHEAD = 4
|
||||
|
||||
# 通道 2:关卡门控
|
||||
CH2_KEEP = {0, 1, 2, 9, 10}
|
||||
CH2_LOSS = {2, 9, 10}
|
||||
CH2_D_MODEL = 64
|
||||
CH2_NHEAD = 4
|
||||
|
||||
# 通道 3:收集资源
|
||||
CH3_KEEP = None # 完整地图,无需切片
|
||||
CH3_LOSS = {3, 4, 5, 6, 7, 8}
|
||||
CH3_D_MODEL = 64
|
||||
CH3_NHEAD = 4
|
||||
|
||||
# 三路共用的 VQ 超参
|
||||
VQ_L = 2
|
||||
VQ_K = 16
|
||||
VQ_D_Z = 64
|
||||
VQ_LAYERS = 2
|
||||
VQ_DIM_FF = 256
|
||||
VQ_BETA = 0.25 # commit loss 权重
|
||||
VQ_GAMMA = 0.1 # entropy loss 权重
|
||||
|
||||
# 解码头超参(三路共用相同规格)
|
||||
DH_NHEAD = 4
|
||||
DH_DIM_FF = 256
|
||||
DH_LAYERS = 2
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 设备
|
||||
# ---------------------------------------------------------------------------
|
||||
device = torch.device(
|
||||
"cuda:1" if torch.cuda.is_available()
|
||||
else "mps" if torch.backends.mps.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
|
||||
os.makedirs("result/pretrain_split", exist_ok=True)
|
||||
|
||||
disable_tqdm = not sys.stdout.isatty()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 参数解析
|
||||
# ---------------------------------------------------------------------------
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(description="三通道分拆 VQ 编码器预训练(方案 B)")
|
||||
parser.add_argument("--resume", type=bool, default=False)
|
||||
parser.add_argument("--state", type=str, default="result/pretrain_split/split-10.pth",
|
||||
help="续训时加载的检查点路径")
|
||||
parser.add_argument("--train", type=str, default="ginka-dataset.json")
|
||||
parser.add_argument("--validate", type=str, default="ginka-eval.json")
|
||||
parser.add_argument("--epochs", type=int, default=60)
|
||||
parser.add_argument("--checkpoint", type=int, default=5,
|
||||
help="每隔多少 epoch 保存检查点并输出验证指标")
|
||||
parser.add_argument("--load_optim", type=bool, default=True)
|
||||
return parser.parse_args()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 验证:各通道专属 tile 召回率 + codebook 使用熵
|
||||
# ---------------------------------------------------------------------------
|
||||
@torch.no_grad()
|
||||
def validate(
|
||||
enc1, enc2, enc3,
|
||||
head1, head2, head3,
|
||||
dataloader_val: DataLoader,
|
||||
) -> dict:
|
||||
for m in [enc1, enc2, enc3, head1, head2, head3]:
|
||||
m.eval()
|
||||
|
||||
# 每类 tile 的 tp / gt 计数
|
||||
ch1_tp, ch1_gt = 0, 0 # wall(1)
|
||||
ch2_tp = {t: 0 for t in CH2_LOSS} # {2,9,10}
|
||||
ch2_gt = {t: 0 for t in CH2_LOSS}
|
||||
ch3_tp = {t: 0 for t in CH3_LOSS} # {3,4,5,6,7,8}
|
||||
ch3_gt = {t: 0 for t in CH3_LOSS}
|
||||
|
||||
# codebook 使用频次(用于熵估算)
|
||||
codebook_counts = [
|
||||
torch.zeros(VQ_K, dtype=torch.long), # 通道 1
|
||||
torch.zeros(VQ_K, dtype=torch.long), # 通道 2
|
||||
torch.zeros(VQ_K, dtype=torch.long), # 通道 3
|
||||
]
|
||||
|
||||
for batch in tqdm(dataloader_val, desc="Validating", leave=False, disable=disable_tqdm):
|
||||
raw_map = batch["raw_map"].to(device)
|
||||
s1 = batch["slice1"].to(device)
|
||||
s2 = batch["slice2"].to(device)
|
||||
s3 = batch["slice3"].to(device)
|
||||
|
||||
# 通道 1
|
||||
z_q1, _, idx1, _, _, _ = enc1(s1)
|
||||
logits1 = head1(z_q1)
|
||||
pred1 = logits1.argmax(dim=-1) # [B, H*W]
|
||||
wall_m = (raw_map == 1)
|
||||
ch1_tp += (pred1[wall_m] == 1).sum().item()
|
||||
ch1_gt += wall_m.sum().item()
|
||||
for code in idx1.view(-1).cpu():
|
||||
codebook_counts[0][code] += 1
|
||||
|
||||
# 通道 2
|
||||
z_q2, _, idx2, _, _, _ = enc2(s2)
|
||||
logits2 = head2(z_q2)
|
||||
pred2 = logits2.argmax(dim=-1)
|
||||
for t in CH2_LOSS:
|
||||
m = (raw_map == t)
|
||||
ch2_tp[t] += (pred2[m] == t).sum().item()
|
||||
ch2_gt[t] += m.sum().item()
|
||||
for code in idx2.view(-1).cpu():
|
||||
codebook_counts[1][code] += 1
|
||||
|
||||
# 通道 3
|
||||
z_q3, _, idx3, _, _, _ = enc3(s3)
|
||||
logits3 = head3(z_q3)
|
||||
pred3 = logits3.argmax(dim=-1)
|
||||
for t in CH3_LOSS:
|
||||
m = (raw_map == t)
|
||||
ch3_tp[t] += (pred3[m] == t).sum().item()
|
||||
ch3_gt[t] += m.sum().item()
|
||||
for code in idx3.view(-1).cpu():
|
||||
codebook_counts[2][code] += 1
|
||||
|
||||
def _entropy(counts):
|
||||
"""估算 codebook 使用熵(bits)。"""
|
||||
counts = counts.float() + 1e-8
|
||||
p = counts / counts.sum()
|
||||
return float(-(p * torch.log2(p)).sum())
|
||||
|
||||
metrics = {
|
||||
"ch1_wall_recall": ch1_tp / max(ch1_gt, 1),
|
||||
"ch2_recall": {t: ch2_tp[t] / max(ch2_gt[t], 1) for t in CH2_LOSS},
|
||||
"ch3_recall": {t: ch3_tp[t] / max(ch3_gt[t], 1) for t in CH3_LOSS},
|
||||
"codebook_entropy": [_entropy(c) for c in codebook_counts],
|
||||
}
|
||||
return metrics
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 主训练函数
|
||||
# ---------------------------------------------------------------------------
|
||||
def train():
|
||||
print(f"Using device: {device}")
|
||||
args = parse_arguments()
|
||||
|
||||
# ---- 三路编码器 ----
|
||||
enc1 = GinkaVQVAE(
|
||||
num_classes=NUM_CLASSES, L=VQ_L, K=VQ_K, d_z=VQ_D_Z,
|
||||
d_model=CH1_D_MODEL, nhead=CH1_NHEAD, num_layers=VQ_LAYERS,
|
||||
dim_ff=VQ_DIM_FF, beta=VQ_BETA, gamma=VQ_GAMMA,
|
||||
).to(device)
|
||||
|
||||
enc2 = GinkaVQVAE(
|
||||
num_classes=NUM_CLASSES, L=VQ_L, K=VQ_K, d_z=VQ_D_Z,
|
||||
d_model=CH2_D_MODEL, nhead=CH2_NHEAD, num_layers=VQ_LAYERS,
|
||||
dim_ff=VQ_DIM_FF, beta=VQ_BETA, gamma=VQ_GAMMA,
|
||||
).to(device)
|
||||
|
||||
enc3 = GinkaVQVAE(
|
||||
num_classes=NUM_CLASSES, L=VQ_L, K=VQ_K, d_z=VQ_D_Z,
|
||||
d_model=CH3_D_MODEL, nhead=CH3_NHEAD, num_layers=VQ_LAYERS,
|
||||
dim_ff=VQ_DIM_FF, beta=VQ_BETA, gamma=VQ_GAMMA,
|
||||
).to(device)
|
||||
|
||||
# ---- 三路解码头(预训练专用,训练后丢弃)----
|
||||
head1 = VQDecodeHead(
|
||||
num_classes=NUM_CLASSES, d_z=VQ_D_Z, map_size=MAP_SIZE,
|
||||
nhead=DH_NHEAD, dim_ff=DH_DIM_FF, num_layers=DH_LAYERS,
|
||||
).to(device)
|
||||
|
||||
head2 = VQDecodeHead(
|
||||
num_classes=NUM_CLASSES, d_z=VQ_D_Z, map_size=MAP_SIZE,
|
||||
nhead=DH_NHEAD, dim_ff=DH_DIM_FF, num_layers=DH_LAYERS,
|
||||
).to(device)
|
||||
|
||||
head3 = VQDecodeHead(
|
||||
num_classes=NUM_CLASSES, d_z=VQ_D_Z, map_size=MAP_SIZE,
|
||||
nhead=DH_NHEAD, dim_ff=DH_DIM_FF, num_layers=DH_LAYERS,
|
||||
).to(device)
|
||||
|
||||
# ---- 优化器(三路同步训练) ----
|
||||
optimizer = optim.AdamW(
|
||||
list(enc1.parameters()) + list(enc2.parameters()) + list(enc3.parameters()) +
|
||||
list(head1.parameters()) + list(head2.parameters()) + list(head3.parameters()),
|
||||
lr=1e-3,
|
||||
weight_decay=1e-4,
|
||||
)
|
||||
|
||||
start_epoch = 0
|
||||
|
||||
# ---- 续训 ----
|
||||
if args.resume:
|
||||
ckpt = torch.load(args.state, map_location=device)
|
||||
enc1.load_state_dict(ckpt["enc1"])
|
||||
enc2.load_state_dict(ckpt["enc2"])
|
||||
enc3.load_state_dict(ckpt["enc3"])
|
||||
head1.load_state_dict(ckpt["head1"])
|
||||
head2.load_state_dict(ckpt["head2"])
|
||||
head3.load_state_dict(ckpt["head3"])
|
||||
if args.load_optim and "optimizer" in ckpt:
|
||||
optimizer.load_state_dict(ckpt["optimizer"])
|
||||
start_epoch = ckpt.get("epoch", 0)
|
||||
print(f"Resumed from epoch {start_epoch}: {args.state}")
|
||||
|
||||
# ---- 数据集 ----
|
||||
ds_train = GinkaSplitDataset(args.train)
|
||||
ds_val = GinkaSplitDataset(args.validate)
|
||||
dl_train = DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
|
||||
dl_val = DataLoader(ds_val, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)
|
||||
|
||||
print(f"训练集大小: {len(ds_train)},验证集大小: {len(ds_val)}")
|
||||
|
||||
total_params = (
|
||||
sum(p.numel() for p in enc1.parameters()) +
|
||||
sum(p.numel() for p in enc2.parameters()) +
|
||||
sum(p.numel() for p in enc3.parameters())
|
||||
)
|
||||
print(f"编码器总参数量(三路): {total_params:,} ({total_params / 1e6:.3f}M)")
|
||||
|
||||
# ---- 训练循环 ----
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
for m in [enc1, enc2, enc3, head1, head2, head3]:
|
||||
m.train()
|
||||
|
||||
total_loss = 0.0
|
||||
ch_losses = [0.0, 0.0, 0.0]
|
||||
|
||||
for batch in tqdm(dl_train, desc=f"Epoch {epoch + 1}/{args.epochs}", disable=disable_tqdm):
|
||||
raw_map = batch["raw_map"].to(device)
|
||||
s1 = batch["slice1"].to(device)
|
||||
s2 = batch["slice2"].to(device)
|
||||
s3 = batch["slice3"].to(device)
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
# ─── 通道 1 ───
|
||||
z_q1, _, _, vq_loss1, commit_loss1, entropy_loss1 = enc1(s1)
|
||||
logits1 = head1(z_q1) # [B, H*W, C]
|
||||
fl1 = masked_focal(logits1, raw_map, CH1_LOSS, gamma=FOCAL_GAMMA)
|
||||
loss1 = fl1 + VQ_BETA * commit_loss1 + VQ_GAMMA * entropy_loss1
|
||||
|
||||
# ─── 通道 2 ───
|
||||
z_q2, _, _, vq_loss2, commit_loss2, entropy_loss2 = enc2(s2)
|
||||
logits2 = head2(z_q2)
|
||||
fl2 = masked_focal(logits2, raw_map, CH2_LOSS, gamma=FOCAL_GAMMA)
|
||||
loss2 = fl2 + VQ_BETA * commit_loss2 + VQ_GAMMA * entropy_loss2
|
||||
|
||||
# ─── 通道 3 ───
|
||||
z_q3, _, _, vq_loss3, commit_loss3, entropy_loss3 = enc3(s3)
|
||||
logits3 = head3(z_q3)
|
||||
fl3 = masked_focal(logits3, raw_map, CH3_LOSS, gamma=FOCAL_GAMMA)
|
||||
loss3 = fl3 + VQ_BETA * commit_loss3 + VQ_GAMMA * entropy_loss3
|
||||
|
||||
loss = loss1 + loss2 + loss3
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(
|
||||
list(enc1.parameters()) + list(enc2.parameters()) + list(enc3.parameters()) +
|
||||
list(head1.parameters()) + list(head2.parameters()) + list(head3.parameters()),
|
||||
max_norm=1.0,
|
||||
)
|
||||
optimizer.step()
|
||||
|
||||
total_loss += loss.item()
|
||||
ch_losses[0] += loss1.item()
|
||||
ch_losses[1] += loss2.item()
|
||||
ch_losses[2] += loss3.item()
|
||||
|
||||
n_batches = len(dl_train)
|
||||
print(
|
||||
f"[{epoch + 1:03d}] total={total_loss / n_batches:.4f} "
|
||||
f"ch1={ch_losses[0] / n_batches:.4f} "
|
||||
f"ch2={ch_losses[1] / n_batches:.4f} "
|
||||
f"ch3={ch_losses[2] / n_batches:.4f}"
|
||||
)
|
||||
|
||||
# ---- 检查点 & 验证 ----
|
||||
if (epoch + 1) % args.checkpoint == 0 or epoch + 1 == args.epochs:
|
||||
metrics = validate(enc1, enc2, enc3, head1, head2, head3, dl_val)
|
||||
print(
|
||||
f" 验证 ch1_wall_recall={metrics['ch1_wall_recall']:.3f} "
|
||||
f"ch2_recall={metrics['ch2_recall']} "
|
||||
f"ch3_recall={metrics['ch3_recall']}"
|
||||
)
|
||||
print(
|
||||
f" codebook_entropy ch1={metrics['codebook_entropy'][0]:.3f} "
|
||||
f"ch2={metrics['codebook_entropy'][1]:.3f} "
|
||||
f"ch3={metrics['codebook_entropy'][2]:.3f}"
|
||||
)
|
||||
|
||||
ts = datetime.now().strftime("%m%d-%H%M")
|
||||
ckpt_path = f"result/pretrain_split/split-{epoch + 1}.pth"
|
||||
torch.save({
|
||||
"epoch": epoch + 1,
|
||||
"enc1": enc1.state_dict(),
|
||||
"enc2": enc2.state_dict(),
|
||||
"enc3": enc3.state_dict(),
|
||||
"head1": head1.state_dict(),
|
||||
"head2": head2.state_dict(),
|
||||
"head3": head3.state_dict(),
|
||||
"optimizer": optimizer.state_dict(),
|
||||
"metrics": metrics,
|
||||
"ts": ts,
|
||||
}, ckpt_path)
|
||||
print(f" Saved checkpoint: {ckpt_path}")
|
||||
|
||||
# ---- 保存最终编码器权重(供联合训练加载) ----
|
||||
final_path = "result/pretrain_split/split_final.pth"
|
||||
torch.save({
|
||||
"epoch": args.epochs,
|
||||
"enc1": enc1.state_dict(),
|
||||
"enc2": enc2.state_dict(),
|
||||
"enc3": enc3.state_dict(),
|
||||
# 解码头不迁移,不保存
|
||||
}, final_path)
|
||||
print(f"\n预训练完成,编码器权重已保存至: {final_path}")
|
||||
print("接下来运行联合训练(阶段 1 冻结热身):")
|
||||
print(f" python -m ginka.train_vq --pretrain_split {final_path} --freeze_vq True")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||
@ -45,16 +45,24 @@ MAP_H = MAP_W = 13
|
||||
FOCAL_GAMMA = 2.0 # focal loss 聚焦参数(越大越关注难例/稀有类别)
|
||||
WALL_MASK_RATIO = 0.8
|
||||
|
||||
# VQ-VAE 超参
|
||||
VQ_L = 2 # summary token 数量(即 z 的序列长度)
|
||||
VQ_K = 8 # codebook 大小
|
||||
VQ_D_Z = 128 # codebook 嵌入维度
|
||||
VQ_D_MODEL= 192
|
||||
VQ_NHEAD = 8
|
||||
VQ_LAYERS = 4
|
||||
VQ_DIM_FF = 512
|
||||
VQ_BETA = 0.5 # commit loss 权重
|
||||
VQ_GAMMA = 0.0 # entropy loss 权重
|
||||
# VQ-VAE 公共超参(三路编码器共用,方案 B 三通道分拆)
|
||||
VQ_L = 2 # 每路码字序列长度(三路合计 L1+L2+L3 = 6)
|
||||
VQ_K = 16 # codebook 大小
|
||||
VQ_D_Z = 64 # codebook 嵌入维度(三路保持一致,便于拼接)
|
||||
VQ_BETA = 0.25 # commit loss 权重
|
||||
VQ_GAMMA = 0.1 # entropy loss 权重
|
||||
|
||||
# 各通道编码器配置
|
||||
CH1_D_MODEL = 128; CH1_NHEAD = 4 # 通道 1:空间骨架(floor+wall)
|
||||
CH2_D_MODEL = 64; CH2_NHEAD = 4 # 通道 2:关卡门控
|
||||
CH3_D_MODEL = 64; CH3_NHEAD = 4 # 通道 3:收集资源
|
||||
VQ_LAYERS = 2
|
||||
VQ_DIM_FF = 256
|
||||
|
||||
# 通道专属损失计算范围(用于监控验证召回率)
|
||||
CH1_LOSS = {1}
|
||||
CH2_LOSS = {2, 9, 10}
|
||||
CH3_LOSS = {3, 4, 5, 6, 7, 8}
|
||||
|
||||
# MaskGIT 超参
|
||||
MG_D_MODEL = 256
|
||||
@ -102,9 +110,12 @@ def parse_arguments():
|
||||
parser.add_argument("--checkpoint", type=int, default=5,
|
||||
help="每隔多少 epoch 保存检查点并验证")
|
||||
parser.add_argument("--load_optim", type=bool, default=True)
|
||||
parser.add_argument("--freeze_vq", type=bool, default=False,
|
||||
help="(方案 D 阶段 1)冻结 VQ 编码器,仅训练 MaskGIT。"
|
||||
parser.add_argument("--freeze_vq", type=bool, default=False,
|
||||
help="(方案 B 阶段 1)冻结三路 VQ 编码器,仅训练 MaskGIT。"
|
||||
"适用于预训练权重加载后的热身阶段。")
|
||||
parser.add_argument("--pretrain_split", type=str, default="",
|
||||
help="(方案 B)三通道分拆预训练检查点路径;"
|
||||
"指定后将从该检查点加载三路编码器初始权重。")
|
||||
return parser.parse_args()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -352,7 +363,9 @@ def make_random_struct_cond() -> torch.Tensor:
|
||||
|
||||
@torch.no_grad()
|
||||
def validate(
|
||||
model_vq: GinkaVQVAE,
|
||||
enc1: GinkaVQVAE,
|
||||
enc2: GinkaVQVAE,
|
||||
enc3: GinkaVQVAE,
|
||||
model_mg: GinkaMaskGIT,
|
||||
dataloader_val: DataLoader,
|
||||
tile_dict: dict,
|
||||
@ -373,7 +386,8 @@ def validate(
|
||||
场景5 (scene5_random) : 无数据集参照,随机稀疏墙壁种子 → 完全随机生成
|
||||
列: random seed | z_rand×(N+1)
|
||||
"""
|
||||
model_vq.eval()
|
||||
for enc in [enc1, enc2, enc3]:
|
||||
enc.eval()
|
||||
model_mg.eval()
|
||||
|
||||
# 按 epoch 建立独立子文件夹,保留每次验证结果方便回溯
|
||||
@ -385,14 +399,33 @@ def validate(
|
||||
captured = {s: None for s in ('A', 'B', 'C', 'D')}
|
||||
|
||||
# ── 计算 val loss + 捕获各子集样本 ──────────────────────────────────────
|
||||
def _encode_three(s1, s2, s3):
|
||||
"""三路编码并拼接 z_q。"""
|
||||
z_q1, _, _, vq1, _, _ = enc1(s1)
|
||||
z_q2, _, _, vq2, _, _ = enc2(s2)
|
||||
z_q3, _, _, vq3, _, _ = enc3(s3)
|
||||
z_q = torch.cat([z_q1, z_q2, z_q3], dim=1) # [B, L1+L2+L3, d_z]
|
||||
vq_loss = vq1 + vq2 + vq3
|
||||
return z_q, vq_loss
|
||||
|
||||
def _sample_three(B_size):
|
||||
"""Sample z from each of the three encoders and concatenate."""
|
||||
z1 = enc1.sample(B_size, device)
|
||||
z2 = enc2.sample(B_size, device)
|
||||
z3 = enc3.sample(B_size, device)
|
||||
return torch.cat([z1, z2, z3], dim=1)
|
||||
|
||||
for batch in tqdm(dataloader_val, desc="Validating", leave=False, disable=disable_tqdm):
|
||||
raw_map = batch["raw_map"].to(device) # [B, 169]
|
||||
masked_map = batch["masked_map"].to(device) # [B, 169]
|
||||
target_map = batch["target_map"].to(device) # [B, 169]
|
||||
s1 = batch["slice1"].to(device)
|
||||
s2 = batch["slice2"].to(device)
|
||||
s3 = batch["slice3"].to(device)
|
||||
subsets = batch["subset"] # list of str
|
||||
B = raw_map.shape[0]
|
||||
|
||||
z_q, _, _, vq_loss, _, _ = model_vq(raw_map)
|
||||
z_q, vq_loss = _encode_three(s1, s2, s3)
|
||||
struct_cond_b = batch["struct_cond"].to(device) # [B, 4]
|
||||
logits = model_mg(masked_map, z_q, struct_cond=struct_cond_b)
|
||||
mask = (masked_map == MASK_TOKEN)
|
||||
@ -419,7 +452,7 @@ def validate(
|
||||
def _rand_gens(cond_map, n):
|
||||
imgs = []
|
||||
for i in range(n):
|
||||
z_r = model_vq.sample(1, device)
|
||||
z_r = _sample_three(1)
|
||||
gen = maskgit_generate(model_mg, z_r, init_map=cond_map) # struct_cond=None: unconditional
|
||||
imgs.append(label_image(make_map_image(gen[0], tile_dict), f"z_rand_{i + 1}"))
|
||||
return imgs
|
||||
@ -428,7 +461,7 @@ def validate(
|
||||
def _rand_gens_with_struct(cond_map, n):
|
||||
imgs = []
|
||||
for i in range(n):
|
||||
z_r = model_vq.sample(1, device)
|
||||
z_r = _sample_three(1)
|
||||
sc_r = make_random_struct_cond() # [1, 4] random valid labels
|
||||
gen = maskgit_generate(model_mg, z_r, init_map=cond_map, struct_cond=sc_r)
|
||||
img = label_image(make_map_image(gen[0], tile_dict), f"z_rand_{i + 1}")
|
||||
@ -539,15 +572,15 @@ def train():
|
||||
print(f"Using device: {device}")
|
||||
args = parse_arguments()
|
||||
|
||||
# ---- Models ----
|
||||
model_vq = GinkaVQVAE(
|
||||
num_classes=NUM_CLASSES,
|
||||
L=VQ_L, K=VQ_K, d_z=VQ_D_Z,
|
||||
d_model=VQ_D_MODEL, nhead=VQ_NHEAD,
|
||||
num_layers=VQ_LAYERS, dim_ff=VQ_DIM_FF,
|
||||
map_size=MAP_SIZE,
|
||||
# ---- Three encoders (Scheme B: three-channel split) ----
|
||||
_vq_common = dict(
|
||||
num_classes=NUM_CLASSES, L=VQ_L, K=VQ_K, d_z=VQ_D_Z,
|
||||
num_layers=VQ_LAYERS, dim_ff=VQ_DIM_FF, map_size=MAP_SIZE,
|
||||
beta=VQ_BETA, gamma=VQ_GAMMA,
|
||||
).to(device)
|
||||
)
|
||||
enc1 = GinkaVQVAE(d_model=CH1_D_MODEL, nhead=CH1_NHEAD, **_vq_common).to(device)
|
||||
enc2 = GinkaVQVAE(d_model=CH2_D_MODEL, nhead=CH2_NHEAD, **_vq_common).to(device)
|
||||
enc3 = GinkaVQVAE(d_model=CH3_D_MODEL, nhead=CH3_NHEAD, **_vq_common).to(device)
|
||||
|
||||
model_mg = GinkaMaskGIT(
|
||||
num_classes=NUM_CLASSES,
|
||||
@ -559,11 +592,11 @@ def train():
|
||||
struct_dropout=MG_STRUCT_DROPOUT,
|
||||
).to(device)
|
||||
|
||||
vq_params = sum(p.numel() for p in model_vq.parameters())
|
||||
enc_params = sum(p.numel() for m in [enc1, enc2, enc3] for p in m.parameters())
|
||||
mg_params = sum(p.numel() for p in model_mg.parameters())
|
||||
print(f"VQ-VAE params: {vq_params:,} ({vq_params/1e6:.3f}M)")
|
||||
print(f"MaskGIT params: {mg_params:,} ({mg_params/1e6:.3f}M)")
|
||||
print(f"Total params: {vq_params+mg_params:,} ({(vq_params+mg_params)/1e6:.3f}M)")
|
||||
print(f"Encoder params (three encoders): {enc_params:,} ({enc_params/1e6:.3f}M)")
|
||||
print(f"MaskGIT params: {mg_params:,} ({mg_params/1e6:.3f}M)")
|
||||
print(f"Total params: {enc_params+mg_params:,} ({(enc_params+mg_params)/1e6:.3f}M)")
|
||||
|
||||
# ---- Datasets ----
|
||||
dataset_train = GinkaVQDataset(
|
||||
@ -587,19 +620,29 @@ def train():
|
||||
num_workers=0,
|
||||
)
|
||||
|
||||
# ---- Optimizer (joint training; both models share one optimizer) ----
|
||||
all_params = list(model_vq.parameters()) + list(model_mg.parameters())
|
||||
# ---- Optimizer (joint training; shared by the three encoders and MaskGIT) ----
|
||||
enc_params_list = list(enc1.parameters()) + list(enc2.parameters()) + list(enc3.parameters())
|
||||
all_params = enc_params_list + list(model_mg.parameters())
|
||||
optimizer = optim.AdamW(all_params, lr=2e-4, weight_decay=1e-2)
|
||||
scheduler = optim.lr_scheduler.CosineAnnealingLR(
|
||||
optimizer, T_max=args.epochs, eta_min=1e-6
|
||||
)
|
||||
|
||||
# ---- Resume training ----
|
||||
# ---- Weight loading ----
|
||||
start_epoch = 0
|
||||
if args.resume:
|
||||
if args.pretrain_split:
|
||||
# Load initial weights for the three encoders from the split-pretraining checkpoint (before the stage 1 frozen warm-up)
|
||||
ckpt = torch.load(args.pretrain_split, map_location=device)
|
||||
enc1.load_state_dict(ckpt["enc1"])
|
||||
enc2.load_state_dict(ckpt["enc2"])
|
||||
enc3.load_state_dict(ckpt["enc3"])
|
||||
print(f"Loaded split-pretrained encoder weights: {args.pretrain_split}")
|
||||
elif args.resume:
|
||||
ckpt = torch.load(args.state, map_location=device)
|
||||
model_vq.load_state_dict(ckpt["vq_state"], strict=False)
|
||||
model_mg.load_state_dict(ckpt["mg_state"], strict=False)
|
||||
enc1.load_state_dict(ckpt["enc1"], strict=False)
|
||||
enc2.load_state_dict(ckpt["enc2"], strict=False)
|
||||
enc3.load_state_dict(ckpt["enc3"], strict=False)
|
||||
model_mg.load_state_dict(ckpt["mg_state"], strict=False)
|
||||
if args.load_optim and ckpt.get("optim_state") is not None:
|
||||
optimizer.load_state_dict(ckpt["optim_state"])
|
||||
start_epoch = ckpt.get("epoch", 0)
|
||||
@ -613,16 +656,18 @@ def train():
|
||||
if img is not None:
|
||||
tile_dict[name] = img
|
||||
|
||||
# ---- Scheme D, stage 1: freeze the VQ encoder ----
|
||||
# ---- Scheme B, stage 1: freeze the three VQ encoders ----
|
||||
if args.freeze_vq:
|
||||
for p in model_vq.parameters():
|
||||
p.requires_grad_(False)
|
||||
print("VQ encoder frozen (Scheme D, stage 1: MaskGIT warm-up).")
|
||||
for enc in [enc1, enc2, enc3]:
|
||||
for p in enc.parameters():
|
||||
p.requires_grad_(False)
|
||||
print("Three VQ encoders frozen (stage 1: MaskGIT warm-up).")
|
||||
|
||||
# ---- Training loop ----
|
||||
for epoch in tqdm(range(start_epoch, start_epoch + args.epochs),
|
||||
desc="Joint Training", disable=disable_tqdm):
|
||||
model_vq.train()
|
||||
for enc in [enc1, enc2, enc3]:
|
||||
enc.train()
|
||||
model_mg.train()
|
||||
|
||||
loss_total = 0.0
|
||||
@ -638,13 +683,23 @@ def train():
|
||||
raw_map = batch["raw_map"].to(device) # [B, 169]
|
||||
masked_map = batch["masked_map"].to(device) # [B, 169]
|
||||
target_map = batch["target_map"].to(device) # [B, 169]
|
||||
s1 = batch["slice1"].to(device) # channel 1 slice
|
||||
s2 = batch["slice2"].to(device) # channel 2 slice
|
||||
s3 = batch["slice3"].to(device) # channel 3 slice
|
||||
|
||||
for s in batch["subset"]:
|
||||
subset_stats[s] = subset_stats.get(s, 0) + 1
|
||||
|
||||
# ---- Forward pass ----
|
||||
# 1. The VQ-VAE encodes the real map → z_q, z_e
|
||||
z_q, z_e, _, vq_loss, commit_loss, entropy_loss = model_vq(raw_map) # z_q/z_e: [B, L, d_z]
|
||||
# 1. Each of the three VQ encoders encodes its own slice → concatenate z
|
||||
z_q1, z_e1, _, vq_loss1, commit_loss1, entropy_loss1 = enc1(s1)
|
||||
z_q2, z_e2, _, vq_loss2, commit_loss2, entropy_loss2 = enc2(s2)
|
||||
z_q3, z_e3, _, vq_loss3, commit_loss3, entropy_loss3 = enc3(s3)
|
||||
z_q = torch.cat([z_q1, z_q2, z_q3], dim=1) # [B, L1+L2+L3, d_z]
|
||||
z_e = torch.cat([z_e1, z_e2, z_e3], dim=1) # [B, L1+L2+L3, d_z]
|
||||
vq_loss = vq_loss1 + vq_loss2 + vq_loss3
|
||||
commit_loss = commit_loss1 + commit_loss2 + commit_loss3
|
||||
entropy_loss = entropy_loss1 + entropy_loss2 + entropy_loss3
|
||||
|
||||
# 2. MaskGIT predicts the original tiles from the masked map + z + structure labels
|
||||
struct_cond = batch["struct_cond"].to(device) # [B, 4]
|
||||
@ -655,23 +710,25 @@ def train():
|
||||
ce_loss = focal_loss(logits.permute(0, 2, 1), target_map)
|
||||
masked_ce = (ce_loss * mask).sum() / (mask.sum() + 1e-6)
|
||||
|
||||
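# 4. z consistency constraint (Scheme A): temperature-smooth the MaskGIT logits,
|
||||
# take the probability-weighted sum of the VQ encoder's tile embeddings to get a soft embedding sequence,
|
||||
# then feed it to the encoder to obtain z_pred_e, which is constrained to match the real z_e.
|
||||
# Gradients flow from z_pred_e back to the MaskGIT logits;
|
||||
# VQ parameters are temporarily frozen on this path (requires_grad=False),
|
||||
# so encoder weights are updated only through the real-map path (vq_loss) and are not skewed by the consistency loss.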
|
||||
for p in model_vq.parameters():
|
||||
p.requires_grad_(False)
|
||||
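# 4. z consistency constraint (Scheme A extended to three channels):
|
||||
# temperature-smooth the MaskGIT logits and take the probability-weighted sum of each encoder's tile embeddings
|
||||
# to get soft embeddings → each encoder encodes them again → z_pred_e_k is aligned with the real z_e_k.
|
||||
# Encoder weights are temporarily frozen on this path so that gradients flow only to MaskGIT.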
|
||||
for enc in [enc1, enc2, enc3]:
|
||||
for p in enc.parameters():
|
||||
p.requires_grad_(False)
|
||||
|
||||
soft_probs = F.softmax(logits / CONSIST_TEMP, dim=-1) # [B, H*W, V]
|
||||
tile_emb = model_vq.tile_embedding.weight # [V, d_model]
|
||||
soft_emb = soft_probs @ tile_emb # [B, H*W, d_model]
|
||||
z_pred_e = model_vq.encode_soft(soft_emb) # [B, L, d_z]
|
||||
soft_probs = F.softmax(logits / CONSIST_TEMP, dim=-1) # [B, H*W, V]
|
||||
z_pred_e1 = enc1.encode_soft(soft_probs @ enc1.tile_embedding.weight)
|
||||
z_pred_e2 = enc2.encode_soft(soft_probs @ enc2.tile_embedding.weight)
|
||||
z_pred_e3 = enc3.encode_soft(soft_probs @ enc3.tile_embedding.weight)
|
||||
z_pred_e = torch.cat([z_pred_e1, z_pred_e2, z_pred_e3], dim=1)
|
||||
consist_loss = F.mse_loss(z_pred_e, z_e.detach())
|
||||
|
||||
for p in model_vq.parameters():
|
||||
p.requires_grad_(True)
|
||||
|
||||
if not args.freeze_vq:
|
||||
for enc in [enc1, enc2, enc3]:
|
||||
for p in enc.parameters():
|
||||
p.requires_grad_(True)
|
||||
|
||||
# 5. Joint loss
|
||||
loss = masked_ce + vq_loss + CONSIST_LAMBDA * consist_loss
|
||||
@ -709,26 +766,31 @@ def train():
|
||||
ckpt_path = f"result/joint/joint-{epoch + 1}.pth"
|
||||
torch.save({
|
||||
"epoch": epoch + 1,
|
||||
"vq_state": model_vq.state_dict(),
|
||||
"enc1": enc1.state_dict(),
|
||||
"enc2": enc2.state_dict(),
|
||||
"enc3": enc3.state_dict(),
|
||||
"mg_state": model_mg.state_dict(),
|
||||
"optim_state":optimizer.state_dict(),
|
||||
}, ckpt_path)
|
||||
tqdm.write(f" Checkpoint saved: {ckpt_path}")
|
||||
|
||||
val_loss = validate(
|
||||
model_vq, model_mg, dataloader_val, tile_dict, epoch + 1
|
||||
enc1, enc2, enc3, model_mg, dataloader_val, tile_dict, epoch + 1
|
||||
)
|
||||
tqdm.write(
|
||||
f"[Validate] Epoch {epoch + 1:4d} | Val Loss {val_loss:.5f}"
|
||||
)
|
||||
# Restore training mode
|
||||
model_vq.train()
|
||||
for enc in [enc1, enc2, enc3]:
|
||||
enc.train()
|
||||
model_mg.train()
|
||||
|
||||
print("Training finished.")
|
||||
torch.save({
|
||||
"epoch": start_epoch + args.epochs,
|
||||
"vq_state": model_vq.state_dict(),
|
||||
"enc1": enc1.state_dict(),
|
||||
"enc2": enc2.state_dict(),
|
||||
"enc3": enc3.state_dict(),
|
||||
"mg_state": model_mg.state_dict(),
|
||||
}, "result/joint/joint_final.pth")
|
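One detail of the flags above deserves a check before running the pipeline. `argparse` passes the raw command-line string to the callable given as `type`, and `bool("False")` is `True` because any non-empty string is truthy. If `--freeze_vq` and `--load_optim` are declared with `type=bool` exactly as shown in this diff, then `--freeze_vq False` would still freeze the encoders. The sketch below reproduces the effect and shows one possible workaround; `str2bool` and `build_parser` are hypothetical names introduced here, not part of the repository.

```python
import argparse


def str2bool(value: str) -> bool:
    """Hypothetical converter: map common true/false spellings to a real bool."""
    lowered = value.strip().lower()
    if lowered in {"1", "true", "t", "yes", "y"}:
        return True
    if lowered in {"0", "false", "f", "no", "n"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser(bool_type) -> argparse.ArgumentParser:
    """Build a tiny parser with the two flags from the diff, using the given converter."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--freeze_vq", type=bool_type, default=False)
    parser.add_argument("--load_optim", type=bool_type, default=True)
    return parser


# The same arguments the stage 2 command passes in train_full.sh.
argv = ["--freeze_vq", "False", "--load_optim", "False"]

naive = build_parser(bool).parse_args(argv)      # type=bool, as in the diff
fixed = build_parser(str2bool).parse_args(argv)  # explicit string-to-bool conversion

print("type=bool     ->", naive.freeze_vq, naive.load_optim)
print("type=str2bool ->", fixed.freeze_vq, fixed.load_optim)
```

Another option, on Python 3.9 or later, is `action=argparse.BooleanOptionalAction`, which gives `--freeze_vq` / `--no-freeze_vq` style flags; that would require changing the calls in `train_full.sh` as well.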
||||
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
def print_memory(device, tag=""):
|
||||
@ -31,4 +32,46 @@ def nms_sampling(noise: np.ndarray, k: int, radius=2):
|
||||
result[y, x] = 1
|
||||
|
||||
return result
|
||||
|
||||
|
||||
|
||||
def masked_focal(
|
||||
logits: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
tile_set: set,
|
||||
gamma: float = 2.0,
|
||||
eps: float = 1e-6,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
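Channel-specific masked Focal Loss: the loss is computed only at positions whose tile is in tile_set.
|
||||
|
||||
Args:
|
||||
logits: [B, H*W, num_classes] decoder head output (before softmax)
|
||||
target: [B, H*W] full-map ground truth (integer tile IDs)
|
||||
tile_set: set of int tiles specific to this channel; all other positions get loss weight 0
|
||||
gamma: Focal Loss focusing parameter
|
||||
eps: small constant added to the denominator for numerical stability
|
||||
|
||||
Returns:
|
||||
scalar tensor channel-specific masked Focal Loss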
|
||||
"""
|
||||
B, S, C = logits.shape
|
||||
|
||||
# Build the mask: True only at positions holding a channel-specific tile
|
||||
mask = torch.zeros(B, S, dtype=torch.bool, device=logits.device)
|
||||
for t in tile_set:
|
||||
mask |= (target == t)
|
||||
|
||||
if not mask.any():
|
||||
return logits.sum() * 0.0 # keep the graph attached; yields zero gradients
|
||||
|
||||
# Focal Loss (reduction='none')
|
||||
ce = F.cross_entropy(
|
||||
logits.view(-1, C),
|
||||
target.view(-1),
|
||||
reduction='none',
|
||||
).view(B, S) # [B, S]
|
||||
|
||||
pt = torch.exp(-ce.detach()) # predicted probability of the true class, stop-gradient
|
||||
fl = (1.0 - pt) ** gamma * ce
|
||||
|
||||
return (fl * mask).sum() / (mask.sum() + eps)
|
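To make the masking behaviour of `masked_focal` concrete, here is a dependency-free reference version written against plain Python lists. It is a sketch for checking the arithmetic, not the repository's implementation: the name `masked_focal_reference` and the flattened `[S][C]` input layout are assumptions, and it computes only the forward value (no gradients).

```python
import math
from typing import Sequence, Set


def masked_focal_reference(
    logits: Sequence[Sequence[float]],
    target: Sequence[int],
    tile_set: Set[int],
    gamma: float = 2.0,
    eps: float = 1e-6,
) -> float:
    """Forward value of the masked focal loss on flattened inputs.

    logits: one row of class scores per position, shape [S][C]
    target: true tile ID per position, shape [S]
    Positions whose true tile is not in tile_set contribute nothing.
    """
    total = 0.0
    count = 0
    for row, tile in zip(logits, target):
        if tile not in tile_set:
            continue  # loss weight 0 outside this channel
        row_max = max(row)
        # Numerically stable log-sum-exp, then cross-entropy for the true class.
        log_z = row_max + math.log(sum(math.exp(v - row_max) for v in row))
        ce = log_z - row[tile]
        pt = math.exp(-ce)  # predicted probability of the true class
        total += (1.0 - pt) ** gamma * ce
        count += 1
    return total / (count + eps)


target = [0, 2, 1]
# The two inputs differ only at positions 0 and 2, whose true tiles are not in {2}.
logits_a = [[0.0, 0.0, 0.0], [2.0, -1.0, 0.5], [0.0, 0.0, 0.0]]
logits_b = [[5.0, -3.0, 1.0], [2.0, -1.0, 0.5], [-4.0, 0.0, 9.0]]

loss_a = masked_focal_reference(logits_a, target, {2})
loss_b = masked_focal_reference(logits_b, target, {2})
loss_empty = masked_focal_reference(logits_a, target, {7})  # tile 7 never occurs

print(loss_a, loss_b, loss_empty)
```

Because `loss_a` and `loss_b` are identical, a decoder cannot lower this loss by fitting positions outside the channel's tile set, which is the property the split design relies on. A reference like this can also serve as a unit-test oracle for the tensor version on small inputs.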
||||
|
||||
@ -1,10 +1,14 @@
|
||||
#!/usr/bin/env bash
|
||||
# ==============================================================================
|
||||
# Full three-stage training pipeline
|
||||
# Full three-stage training pipeline (Scheme B: three-channel split VQ encoders)
|
||||
#
|
||||
# Stage 0 VQ encoder pretraining train_pretrain.py
|
||||
# Stage 1 MaskGIT warm-up train_vq.py --freeze_vq True
|
||||
# Stage 0 Three-channel split pretraining train_pretrain_split.py
|
||||
# enc1 (floor+wall) / enc2 (+door+mob+entrance) / enc3 (full map)
|
||||
# each computes the masked Focal Loss only on its own channel's tiles
|
||||
# Stage 1 MaskGIT warm-up train_vq.py --pretrain_split --freeze_vq True
|
||||
# the three encoders are frozen; only MaskGIT is trained
|
||||
# Stage 2 Full joint training train_vq.py
|
||||
# the three encoders and MaskGIT are all optimized
|
||||
#
|
||||
# Usage:
|
||||
# bash train_full.sh # run all three stages from scratch
|
||||
@ -19,10 +23,10 @@ set -euo pipefail
|
||||
TRAIN_DATA="ginka-dataset.json"
|
||||
EVAL_DATA="ginka-eval.json"
|
||||
|
||||
# Stage 0: pretraining
|
||||
# Stage 0: three-channel split pretraining
|
||||
P0_EPOCHS=100
|
||||
P0_CHECKPOINT=10
|
||||
P0_FINAL="result/pretrain/pretrain_final.pth"
|
||||
P0_FINAL="result/pretrain_split/split_final.pth"
|
||||
|
||||
# Stage 1: warm-up with frozen encoders
|
||||
P1_EPOCHS=50
|
||||
@ -68,14 +72,15 @@ die() {
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Stage 0: VQ encoder pretraining
|
||||
# Stage 0: three-channel split VQ encoder pretraining
|
||||
# ------------------------------------------------------------------------------
|
||||
if [[ $START_PHASE -le 0 ]]; then
|
||||
log "Stage 0 / 3 VQ encoder pretraining (epochs=${P0_EPOCHS})"
|
||||
python3 -u -m ginka.train_pretrain \
|
||||
--train "$TRAIN_DATA" \
|
||||
--validate "$EVAL_DATA" \
|
||||
--epochs "$P0_EPOCHS" \
|
||||
log "Stage 0 / 3 three-channel split VQ pretraining (epochs=${P0_EPOCHS})"
|
||||
mkdir -p result/pretrain_split
|
||||
python3 -u -m ginka.train_pretrain_split \
|
||||
--train "$TRAIN_DATA" \
|
||||
--validate "$EVAL_DATA" \
|
||||
--epochs "$P0_EPOCHS" \
|
||||
--checkpoint "$P0_CHECKPOINT"
|
||||
|
||||
[[ -f "$P0_FINAL" ]] || die "Stage 0 did not produce the expected checkpoint: $P0_FINAL"
|
||||
@ -86,24 +91,23 @@ else
|
||||
fi
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Stage 1: MaskGIT warm-up (VQ encoder frozen)
|
||||
# Stage 1: MaskGIT warm-up (three VQ encoders frozen)
|
||||
# ------------------------------------------------------------------------------
|
||||
if [[ $START_PHASE -le 1 ]]; then
|
||||
log "Stage 1 / 3 MaskGIT warm-up (VQ frozen) (epochs=${P1_EPOCHS})"
|
||||
log "Stage 1 / 3 MaskGIT warm-up (three VQ encoders frozen) (epochs=${P1_EPOCHS})"
|
||||
mkdir -p result/joint
|
||||
python3 -u -m ginka.train_vq \
|
||||
--train "$TRAIN_DATA" \
|
||||
--validate "$EVAL_DATA" \
|
||||
--resume True \
|
||||
--state "$P0_FINAL" \
|
||||
--load_optim False \
|
||||
--freeze_vq True \
|
||||
--epochs "$P1_EPOCHS" \
|
||||
--checkpoint "$P1_CHECKPOINT"
|
||||
--train "$TRAIN_DATA" \
|
||||
--validate "$EVAL_DATA" \
|
||||
--pretrain_split "$P0_FINAL" \
|
||||
--load_optim False \
|
||||
--freeze_vq True \
|
||||
--epochs "$P1_EPOCHS" \
|
||||
--checkpoint "$P1_CHECKPOINT"
|
||||
|
||||
# Last checkpoint from stage 1
|
||||
# Take the last stage 1 checkpoint and pin it as the entry point for stage 2
|
||||
_P1_LAST=$(ls -t result/joint/joint-*.pth 2>/dev/null | head -1)
|
||||
[[ -n "$_P1_LAST" ]] || die "Stage 1 produced no checkpoints (result/joint/joint-*.pth)"
|
||||
# Copy it as the fixed stage 1 final state for stage 2 to load
|
||||
cp "$_P1_LAST" "$P1_FINAL"
|
||||
log "Stage 1 done → $P1_FINAL (from $_P1_LAST)"
|
||||
else
|
||||
@ -112,18 +116,18 @@ else
|
||||
fi
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Stage 2: full joint training
|
||||
# Stage 2: full joint training (three encoders + MaskGIT, all parameters)
|
||||
# ------------------------------------------------------------------------------
|
||||
if [[ $START_PHASE -le 2 ]]; then
|
||||
log "Stage 2 / 3 full joint training (epochs=${P2_EPOCHS})"
|
||||
python3 -u -m ginka.train_vq \
|
||||
--train "$TRAIN_DATA" \
|
||||
--validate "$EVAL_DATA" \
|
||||
--resume True \
|
||||
--state "$P1_FINAL" \
|
||||
--load_optim False \
|
||||
--freeze_vq False \
|
||||
--epochs "$P2_EPOCHS" \
|
||||
--train "$TRAIN_DATA" \
|
||||
--validate "$EVAL_DATA" \
|
||||
--resume True \
|
||||
--state "$P1_FINAL" \
|
||||
--load_optim False \
|
||||
--freeze_vq False \
|
||||
--epochs "$P2_EPOCHS" \
|
||||
--checkpoint "$P2_CHECKPOINT"
|
||||
|
||||
log "Stage 2 done"
|
||||
|
||||
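The script selects the stage 1 hand-off checkpoint with `ls -t result/joint/joint-*.pth | head -1`, that is, by modification time. An alternative is to choose by the epoch number embedded in the file name, which does not depend on timestamps surviving a copy or sync. The sketch below illustrates that idea only. `latest_joint_checkpoint` is a hypothetical helper, and it assumes the `joint-<epoch>.pth` naming seen in the training script.

```python
import re
import tempfile
from pathlib import Path
from typing import Optional

# Matches joint-<epoch>.pth; joint_final.pth is deliberately not matched.
_CKPT_RE = re.compile(r"^joint-(\d+)\.pth$")


def latest_joint_checkpoint(directory: str) -> Optional[Path]:
    """Return the joint-<epoch>.pth file with the highest epoch, or None if absent."""
    root = Path(directory)
    if not root.is_dir():
        return None
    best_epoch = -1
    best_path: Optional[Path] = None
    for path in root.glob("joint-*.pth"):
        match = _CKPT_RE.match(path.name)
        if match and int(match.group(1)) > best_epoch:
            best_epoch = int(match.group(1))
            best_path = path
    return best_path


# Small demonstration on empty placeholder files in a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    for name in ["joint-5.pth", "joint-10.pth", "joint-100.pth", "joint_final.pth"]:
        (Path(tmp) / name).touch()
    picked = latest_joint_checkpoint(tmp)
    picked_name = picked.name if picked else None

print(picked_name)
```

Comparing epochs as integers avoids the lexicographic trap where `joint-5.pth` sorts after `joint-100.pth`. Either selection rule can pick a stale file if `result/joint` still holds checkpoints from an earlier run, so clearing the directory before stage 1 may be the simpler safeguard.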