feat: vq 预训练

Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
unanmed 2026-04-28 16:31:53 +08:00
parent 789107969b
commit a69403d6bf
5 changed files with 607 additions and 1 deletions

View File

@ -8,9 +8,10 @@
| 方案 | 核心思路 | 状态 |
| ------ | -------------------------------------------------- | -------- |
| 方案 A | 重建一致性约束:将生成结果回送编码器,令 z 闭环 | 待细化 |
| 方案 A | 重建一致性约束:将生成结果回送编码器,令 z 闭环 | 已实施 |
| 方案 B | 多路分拆编码:将地图按层次结构分拆为多部分分别编码 | 待细化 |
| 方案 C | 多阶段生成:先墙壁,再门怪,最后资源 | 后续计划 |
| 方案 D | VQ 编码器预训练:先单独训练编码器学会重建,再联合 | 待细化 |
---
@ -167,6 +168,80 @@ MaskGIT Cross-Attentionz 作为 memory
---
---
## 方案 DVQ 编码器预训练
### 问题诊断
当前联合训练时VQ 编码器和 MaskGIT 从随机初始化开始同步优化。由于编码器尚未学到任何地图语义,早期 z 基本是随机噪声MaskGIT 无法从中获得有效的条件信号,两者的优化信号相互干扰,容易导致训练早期陷入局部最优或收敛缓慢。
解决思路:在联合训练开始前,先单独预训练 VQ 编码器,使其具备初步的地图语义理解能力,再以此为初始化启动联合训练。
### 核心思路
为 VQ-VAE 临时增加一个轻量解码头Decoder Head构成完整的自编码器以完整地图重建为目标进行预训练
$$\mathcal{L}_{pretrain} = \mathcal{L}_{CE}^{recon} + \beta \cdot \mathcal{L}_{commit} + \gamma \cdot \mathcal{L}_{uniform}$$
其中 $\mathcal{L}_{CE}^{recon}$ 是对全部 169 个位置的交叉熵重建损失(不做掩码,全图重建)。预训练完成后,解码头被丢弃,编码器权重作为联合训练的初始化。
### 解码头设计
解码头的职责是将 z_q [B, L, d_z] 还原为 [B, H*W, num_classes],有以下两种设计选项:
#### 选项 D-1Cross-Attention 解码头(推荐)
```
z_q [B, L, d_z]
可学习位置查询 [B, H*W, d_z](每个格子对应一个 query
│ Cross-Attentionquery=位置查询key/value=z_q
线性分类头 → logits [B, H*W, num_classes]
```
与 MaskGIT 的 Cross-Attention 结构高度一致,预训练阶段即可验证"z → 地图"的解码路径是否畅通。解码头参数量小(单层 Cross-Attention + Linear预训练速度快。
#### 选项 D-2简单线性展开基线
```
z_q [B, L, d_z]
│ Flatten → Linear
logits [B, H*W, num_classes]
```
实现最简单,但 L × d_z → H\*W × num_classes 的映射会引入大量参数L=32, d_z=128 时约 512K且缺乏空间归纳偏置效果可能较差。
**推荐选项 D-1**,结构与联合训练阶段的 MaskGIT 解码路径一致,预训练阶段已对"z 作为 Cross-Attention memory 驱动生成"这一机制进行充分热身。
### 训练策略
| 阶段 | 模型状态 | 目标 | 建议轮数 |
| -------------------- | ----------------------------- | ----------------------------------------- | ------------ |
| 阶段 0预训练 | 编码器 + 临时解码头 | 全图重建,$\mathcal{L}_{pretrain}$ 收敛 | 2050 epoch |
| 阶段 1联合热身 | 编码器冻结 + MaskGIT 训练 | 让 MaskGIT 先适应固定的 z 分布 | 2040 epoch |
| 阶段 2完整联合训练 | 全部参数解冻,编码器用较小 LR | 端到端联合优化(可叠加方案 A 一致性约束) | 正常训练轮数 |
> 阶段 1 的编码器冻结热身建议执行若直接解冻联合训练MaskGIT 早期的不稳定梯度可能逐渐覆盖编码器预训练获得的语义。考虑到 MaskGIT 收敛速度相对较慢,热身阶段建议适当延长至 2040 epoch。
### 实现要点
1. **解码头独立模块**:将解码头实现为独立的类(如 `VQDecodeHead`),不修改 `GinkaVQVAE` 的核心结构,预训练结束后直接丢弃,不影响联合训练代码路径。
2. **预训练脚本独立**:新增 `ginka/train_pretrain.py`,与联合训练脚本 `train_vq.py` 分离,便于单独调试。
3. **权重迁移**:预训练结束后通过 `model_vq.load_state_dict(ckpt['vq_state'], strict=False)` 将编码器权重加载到联合训练中。
4. **重建质量指标**:预训练阶段重点监控逐类别准确率(尤其是墙壁 tile=1 的召回率确认编码器已学到基本的空间结构语义。需注意codebook 容量远小于训练集数量,预训练的目标更倾向于让编码器学会地图的大致分类,而非像素级完整重建——重建损失在此主要作为分类学习的约束信号。
### 与其他方案的关系
- 方案 D 是**独立于方案 A/B 的训练流程优化**,不修改模型推理时的计算图,与方案 A 的一致性约束、方案 B 的多路编码均可叠加使用。
- 方案 D 完成后,方案 A 的一致性约束的初始条件更好(编码器已具有语义),收敛应更快、更稳定。
- 若最终采用方案 B多路分拆每个通道的编码器均可独立预训练后再联合训练。
---
## 两方案的对比
| 维度 | 方案 Az 闭环) | 方案 B多路分拆 |
@ -182,6 +257,12 @@ MaskGIT Cross-Attentionz 作为 memory
## 实施建议
### 阶段零:预训练编码器(方案 D可选但推荐
1. 实现 `VQDecodeHead`Cross-Attention 解码头)和独立预训练脚本 `ginka/train_pretrain.py`
2. 以全图重建为目标预训练 VQ 编码器 2050 epoch直至重建准确率尤其是墙壁类趋于稳定
3. 保存编码器权重,作为阶段一联合训练的初始化。
### 阶段一:验证方案 A低风险快速验证
1. 在现有联合训练代码中,对子集 A 的训练步骤增加软分布近似一致性损失;
@ -206,6 +287,8 @@ MaskGIT Cross-Attentionz 作为 memory
## 待细化事项
- [x] 方案 A一致性损失的权重 $\lambda$ 如何随训练进度调度?→ 先使用常量(初始值 0.1),效果不佳再引入调度策略。
- [x] 方案 D预训练阶段是否对所有子集数据都进行预训练还是只用完整地图→ 仅使用完整地图raw_map。子集划分的差异体现在输入条件上但输出目标始终是完整地图预训练阶段无需区分子集。
- [x] 方案 D预训练完成后联合训练时编码器是否需要冻结热身阶段→ 建议执行冻结热身。若直接解冻联合训练MaskGIT 的不稳定梯度可能逐渐覆盖编码器预训练所获得的语义;考虑到 MaskGIT 收敛较慢,热身 epoch 数适当增大(建议 2040 epoch
- [x] 方案 A单步解码还是多步解码后计算一致性损失→ 训练时 MaskGIT 只进行单步解码,直接在单步结果上计算,无需多步展开。
- [x] 方案 B通道 2 的"墙壁"是否需要保留,还是只保留入口 + 怪 + 门?→ 保留墙壁。去掉墙壁后剩余内容趋向于散点,缺乏空间结构指导意义。
- [x] 方案 B三路 z 拼接后总长度是否超出 MaskGIT cross-attention 的合理 memory 长度?→ 先直接拼接,如有性能问题再评估截断或压缩策略。

318
ginka/train_pretrain.py Normal file
View File

@ -0,0 +1,318 @@
"""
VQ 编码器预训练脚本方案 D
目标在联合训练开始前先单独预训练 VQ 编码器使其学到地图的大致语义分类
解码头VQDecodeHead仅在预训练阶段使用结束后丢弃权重不迁移到联合训练
训练流程对应设计文档方案 D 三阶段
阶段 0本脚本编码器 + 临时解码头全图重建目标
阶段 1 train_vq.py 编码器冻结 + MaskGIT 热身启用 --freeze_vq
阶段 2 train_vq.py 完整联合训练编码器用较小 LR
用法示例
python -m ginka.train_pretrain
python -m ginka.train_pretrain --resume True --state result/pretrain/pretrain-20.pth
# 预训练完成后,传入权重路径启动联合训练阶段 1
python -m ginka.train_vq --resume True --state result/pretrain/pretrain_final.pth
"""
import argparse
import os
import sys
from datetime import datetime
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from .vqvae.model import GinkaVQVAE, VQDecodeHead
from .dataset import load_data
# ---------------------------------------------------------------------------
# 超参数(须与 train_vq.py 中 VQ-VAE 配置保持一致)
# ---------------------------------------------------------------------------
BATCH_SIZE = 64
NUM_CLASSES = 16
MAP_SIZE = 13 * 13
MAP_H = MAP_W = 13
# VQ-VAE 超参(保持与 train_vq.py 一致)
VQ_L = 32
VQ_K = 1
VQ_D_Z = 128
VQ_D_MODEL= 192
VQ_NHEAD = 8
VQ_LAYERS = 4
VQ_DIM_FF = 512
VQ_BETA = 0.5
VQ_GAMMA = 0.0
# 解码头超参
DH_NHEAD = 8 # Cross-Attention 头数VQ_D_Z=128 可被 8 整除)
# ---------------------------------------------------------------------------
# 设备
# ---------------------------------------------------------------------------
device = torch.device(
"cuda:1" if torch.cuda.is_available()
else "mps" if torch.backends.mps.is_available()
else "cpu"
)
os.makedirs("result/pretrain", exist_ok=True)
disable_tqdm = not sys.stdout.isatty()
# ---------------------------------------------------------------------------
# 简单数据集:仅返回 raw_map无子集划分无掩码
# ---------------------------------------------------------------------------
class GinkaPretrainDataset(Dataset):
"""
预训练专用数据集仅提供完整原始地图raw_map和随机数据增强
不做子集划分与掩码处理重建目标为全图所有 169 个位置
"""
def __init__(self, data_path: str):
self.data = load_data(data_path)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
item = self.data[idx]
arr = np.array(item['map'], dtype=np.int64) # [H, W]
# 随机旋转 / 翻转数据增强
if np.random.rand() > 0.5:
k = np.random.randint(1, 4)
arr = np.rot90(arr, k).copy()
if np.random.rand() > 0.5:
arr = np.fliplr(arr).copy()
if np.random.rand() > 0.5:
arr = np.flipud(arr).copy()
raw_map = torch.tensor(arr.reshape(-1), dtype=torch.long) # [H*W]
return raw_map
# ---------------------------------------------------------------------------
# 参数解析
# ---------------------------------------------------------------------------
def parse_arguments():
parser = argparse.ArgumentParser(description="VQ 编码器预训练(方案 D")
parser.add_argument("--resume", type=bool, default=False)
parser.add_argument("--state", type=str, default="result/pretrain/pretrain-20.pth",
help="续训时加载的检查点路径")
parser.add_argument("--train", type=str, default="ginka-dataset.json")
parser.add_argument("--validate", type=str, default="ginka-eval.json")
parser.add_argument("--epochs", type=int, default=50)
parser.add_argument("--checkpoint", type=int, default=5,
help="每隔多少 epoch 保存检查点并输出验证指标")
parser.add_argument("--load_optim", type=bool, default=True)
return parser.parse_args()
# ---------------------------------------------------------------------------
# 验证:计算全图 top-1 准确率及关键类别(墙壁)召回率
# ---------------------------------------------------------------------------
@torch.no_grad()
def validate(
model_vq: GinkaVQVAE,
decode_head: VQDecodeHead,
dataloader_val: DataLoader,
) -> dict:
model_vq.eval()
decode_head.eval()
total, correct = 0, 0
wall_tp, wall_gt = 0, 0 # wall tile=1 的召回
class_correct = torch.zeros(NUM_CLASSES, dtype=torch.long)
class_total = torch.zeros(NUM_CLASSES, dtype=torch.long)
for raw_map in tqdm(dataloader_val, desc="Validating", leave=False, disable=disable_tqdm):
raw_map = raw_map.to(device) # [B, H*W]
z_q, _, _, _, _, _ = model_vq(raw_map)
logits = decode_head(z_q) # [B, H*W, C]
pred = logits.argmax(dim=-1) # [B, H*W]
correct += (pred == raw_map).sum().item()
total += raw_map.numel()
# 墙壁召回
wall_mask = (raw_map == 1)
wall_tp += (pred[wall_mask] == 1).sum().item()
wall_gt += wall_mask.sum().item()
# 逐类别统计
for c in range(NUM_CLASSES):
mask_c = (raw_map == c)
class_correct[c] += (pred[mask_c] == c).sum().item()
class_total[c] += mask_c.sum().item()
acc = correct / max(total, 1)
wall_rec = wall_tp / max(wall_gt, 1)
# 有样本的类别逐一统计
per_class = {}
for c in range(NUM_CLASSES):
if class_total[c] > 0:
per_class[c] = class_correct[c].item() / class_total[c].item()
return {"acc": acc, "wall_recall": wall_rec, "per_class": per_class}
# ---------------------------------------------------------------------------
# 主训练函数
# ---------------------------------------------------------------------------
def train():
print(f"Using device: {device}")
args = parse_arguments()
# ---- 模型 ----
model_vq = GinkaVQVAE(
num_classes=NUM_CLASSES,
L=VQ_L, K=VQ_K, d_z=VQ_D_Z,
d_model=VQ_D_MODEL, nhead=VQ_NHEAD,
num_layers=VQ_LAYERS, dim_ff=VQ_DIM_FF,
map_size=MAP_SIZE,
beta=VQ_BETA, gamma=VQ_GAMMA,
).to(device)
decode_head = VQDecodeHead(
num_classes=NUM_CLASSES,
d_z=VQ_D_Z,
map_size=MAP_SIZE,
nhead=DH_NHEAD,
).to(device)
vq_params = sum(p.numel() for p in model_vq.parameters())
dh_params = sum(p.numel() for p in decode_head.parameters())
print(f"VQ-VAE 参数量: {vq_params:,} ({vq_params/1e6:.3f}M)")
print(f"DecodeHead 参数量: {dh_params:,} ({dh_params/1e6:.3f}M)")
# ---- 数据集 ----
dataset_train = GinkaPretrainDataset(args.train)
dataset_val = GinkaPretrainDataset(args.validate)
dataloader_train = DataLoader(
dataset_train, batch_size=BATCH_SIZE, shuffle=True,
num_workers=0, pin_memory=(device.type == "cuda"),
)
dataloader_val = DataLoader(
dataset_val, batch_size=BATCH_SIZE, shuffle=False,
num_workers=0,
)
print(f"训练集: {len(dataset_train)} 条 验证集: {len(dataset_val)}")
# ---- 优化器 ----
all_params = list(model_vq.parameters()) + list(decode_head.parameters())
optimizer = optim.AdamW(all_params, lr=2e-4, weight_decay=1e-2)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=args.epochs, eta_min=1e-6
)
# ---- 续训 ----
start_epoch = 0
if args.resume:
ckpt = torch.load(args.state, map_location=device)
model_vq.load_state_dict(ckpt["vq_state"], strict=False)
if "dh_state" in ckpt:
decode_head.load_state_dict(ckpt["dh_state"], strict=False)
if args.load_optim and ckpt.get("optim_state") is not None:
optimizer.load_state_dict(ckpt["optim_state"])
start_epoch = ckpt.get("epoch", 0)
print(f"从 epoch {start_epoch} 接续训练。")
# ---- 训练循环 ----
for epoch in tqdm(range(start_epoch, start_epoch + args.epochs),
desc="VQ Pretrain", disable=disable_tqdm):
model_vq.train()
decode_head.train()
loss_total = 0.0
ce_total = 0.0
commit_total = 0.0
entropy_total = 0.0
for raw_map in tqdm(dataloader_train, leave=False,
desc="Epoch Progress", disable=disable_tqdm):
raw_map = raw_map.to(device) # [B, H*W]
# 1. 编码
z_q, _, _, vq_loss, commit_loss, entropy_loss = model_vq(raw_map)
# 2. 解码→全图重建
logits = decode_head(z_q) # [B, H*W, C]
ce_loss = F.cross_entropy(
logits.permute(0, 2, 1), raw_map # [B, C, H*W] vs [B, H*W]
)
# 3. 总损失(重建 + VQ 正则)
loss = ce_loss + vq_loss
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(all_params, max_norm=1.0)
optimizer.step()
loss_total += loss.detach().item()
ce_total += ce_loss.detach().item()
commit_total += commit_loss.detach().item()
entropy_total += entropy_loss.detach().item()
scheduler.step()
n = len(dataloader_train)
tqdm.write(
f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
f"Epoch {epoch + 1:4d} | "
f"Loss {loss_total/n:.5f} "
f"CE {ce_total/n:.5f} "
f"Commit {commit_total/n:.5f} "
f"Entropy {entropy_total/n:.5f} | "
f"LR {scheduler.get_last_lr()[0]:.6f}"
)
# ---- 检查点 + 验证 ----
if (epoch + 1) % args.checkpoint == 0:
ckpt_path = f"result/pretrain/pretrain-{epoch + 1}.pth"
torch.save({
"epoch": epoch + 1,
"vq_state": model_vq.state_dict(),
"dh_state": decode_head.state_dict(),
"optim_state": optimizer.state_dict(),
}, ckpt_path)
tqdm.write(f" 检查点已保存: {ckpt_path}")
metrics = validate(model_vq, decode_head, dataloader_val)
acc_str = f" [Validate] Acc {metrics['acc']:.4f} Wall Recall {metrics['wall_recall']:.4f}"
# 输出有样本的类别准确率
pc = metrics["per_class"]
detail = " ".join(
f"c{c}={v:.3f}" for c, v in sorted(pc.items()) if v < 1.0
)
if detail:
acc_str += f"\n Per-class: {detail}"
tqdm.write(acc_str)
model_vq.train()
decode_head.train()
# ---- 保存最终 VQ 编码器权重 ----
final_path = "result/pretrain/pretrain_final.pth"
torch.save({
"epoch": start_epoch + args.epochs,
"vq_state": model_vq.state_dict(),
# 不保存解码头:联合训练阶段不需要
}, final_path)
print(f"\n预训练完成。编码器权重已保存至: {final_path}")
print(f"联合训练阶段 1 启动命令(编码器冻结热身):")
print(f" python -m ginka.train_vq --resume True --state {final_path} --freeze_vq True")
# ---------------------------------------------------------------------------
if __name__ == "__main__":
torch.set_num_threads(4)
train()

View File

@ -102,6 +102,9 @@ def parse_arguments():
parser.add_argument("--checkpoint", type=int, default=5,
help="每隔多少 epoch 保存检查点并验证")
parser.add_argument("--load_optim", type=bool, default=True)
parser.add_argument("--freeze_vq", type=bool, default=False,
help="(方案 D 阶段 1冻结 VQ 编码器,仅训练 MaskGIT。"
"适用于预训练权重加载后的热身阶段。")
return parser.parse_args()
# ---------------------------------------------------------------------------
@ -583,6 +586,12 @@ def train():
if img is not None:
tile_dict[name] = img
# ---- 方案 D 阶段 1冻结 VQ 编码器 ----
if args.freeze_vq:
for p in model_vq.parameters():
p.requires_grad_(False)
print("VQ 编码器已冻结(方案 D 阶段 1MaskGIT 热身)。")
# ---- 训练循环 ----
for epoch in tqdm(range(start_epoch, start_epoch + args.epochs),
desc="Joint Training", disable=disable_tqdm):

View File

@ -4,6 +4,66 @@ from .quantize import VectorQuantizer
from typing import Tuple
class VQDecodeHead(nn.Module):
"""
VQ-VAE 预训练用轻量解码头Cross-Attention 架构
z_q [B, L, d_z] 通过 Cross-Attention 还原为地图 logits [B, H*W, num_classes]
预训练结束后此模块被丢弃不影响联合训练路径
架构
可学习位置查询 [B, H*W, d_z]
Cross-Attention (query=位置查询, key/value=z_q)
LayerNorm
线性分类头 logits [B, H*W, num_classes]
"""
def __init__(
self,
num_classes: int,
d_z: int,
map_size: int,
nhead: int = 4,
):
"""
Args:
num_classes: tile 类别数
d_z: z 向量维度须与 GinkaVQVAE d_z 一致
map_size: 地图 token 总数H * W
nhead: Cross-Attention 的注意力头数d_z 须能被 nhead 整除
"""
super().__init__()
# 每个格子一个可学习位置查询
self.pos_queries = nn.Parameter(torch.randn(1, map_size, d_z) * 0.02)
# Cross-Attentionquery=位置查询key/value=z_q
self.cross_attn = nn.MultiheadAttention(
embed_dim=d_z,
num_heads=nhead,
batch_first=True,
dropout=0.0,
)
self.norm = nn.LayerNorm(d_z)
# 最终分类头
self.classifier = nn.Linear(d_z, num_classes)
def forward(self, z_q: torch.Tensor) -> torch.Tensor:
"""
Args:
z_q: [B, L, d_z]
Returns:
logits: [B, map_size, num_classes]
"""
B = z_q.shape[0]
q = self.pos_queries.expand(B, -1, -1) # [B, map_size, d_z]
out, _ = self.cross_attn(q, z_q, z_q) # [B, map_size, d_z]
out = self.norm(out)
return self.classifier(out) # [B, map_size, num_classes]
class GinkaVQVAE(nn.Module):
"""
VQ-VAE 风格地图编码器

136
train_full.sh Normal file
View File

@ -0,0 +1,136 @@
#!/usr/bin/env bash
# ==============================================================================
# 三阶段完整训练流水线
#
# 阶段 0 VQ 编码器预训练 train_pretrain.py
# 阶段 1 MaskGIT 热身 train_vq.py --freeze_vq True
# 阶段 2 完整联合训练 train_vq.py
#
# 用法:
# bash train_full.sh # 从头开始三阶段训练
# bash train_full.sh --skip 1 # 跳过阶段 0从阶段 1 开始
# bash train_full.sh --skip 2 # 跳过阶段 0-1直接阶段 2
# ==============================================================================
set -euo pipefail
# ------------------------------------------------------------------------------
# 超参配置(按需修改)
# ------------------------------------------------------------------------------
TRAIN_DATA="ginka-dataset.json"
EVAL_DATA="ginka-eval.json"
# 阶段 0预训练
P0_EPOCHS=50
P0_CHECKPOINT=5
P0_FINAL="result/pretrain/pretrain_final.pth"
# 阶段 1冻结编码器热身
P1_EPOCHS=30
P1_CHECKPOINT=5
P1_FINAL="result/joint/warmup_final.pth"
# 阶段 2完整联合训练
P2_EPOCHS=400
P2_CHECKPOINT=20
# 从哪个阶段开始0 = 从头);命令行 --skip N 可覆盖此值
START_PHASE=0
# ------------------------------------------------------------------------------
# 解析命令行参数
# ------------------------------------------------------------------------------
while [[ $# -gt 0 ]]; do
case "$1" in
--skip)
START_PHASE="$2"
shift 2
;;
*)
echo "未知参数: $1"; exit 1
;;
esac
done
# ------------------------------------------------------------------------------
# 工具函数
# ------------------------------------------------------------------------------
log() {
echo ""
echo "════════════════════════════════════════════════════════════════"
echo " $*"
echo " $(date '+%Y-%m-%d %H:%M:%S')"
echo "════════════════════════════════════════════════════════════════"
}
die() {
echo "[ERROR] $*" >&2
exit 1
}
# ------------------------------------------------------------------------------
# 阶段 0VQ 编码器预训练
# ------------------------------------------------------------------------------
if [[ $START_PHASE -le 0 ]]; then
log "阶段 0 / 3 VQ 编码器预训练 (epochs=${P0_EPOCHS})"
python -m ginka.train_pretrain \
--train "$TRAIN_DATA" \
--validate "$EVAL_DATA" \
--epochs "$P0_EPOCHS" \
--checkpoint "$P0_CHECKPOINT"
[[ -f "$P0_FINAL" ]] || die "阶段 0 未生成预期检查点:$P0_FINAL"
log "阶段 0 完成 → $P0_FINAL"
else
[[ -f "$P0_FINAL" ]] || die "跳过阶段 0 但找不到检查点:$P0_FINAL"
log "阶段 0 已跳过(使用现有检查点 $P0_FINAL"
fi
# ------------------------------------------------------------------------------
# 阶段 1MaskGIT 热身VQ 编码器冻结)
# ------------------------------------------------------------------------------
if [[ $START_PHASE -le 1 ]]; then
log "阶段 1 / 3 MaskGIT 热身VQ 冻结) (epochs=${P1_EPOCHS})"
python -m ginka.train_vq \
--train "$TRAIN_DATA" \
--validate "$EVAL_DATA" \
--resume True \
--state "$P0_FINAL" \
--load_optim False \
--freeze_vq True \
--epochs "$P1_EPOCHS" \
--checkpoint "$P1_CHECKPOINT"
# 阶段 1 最后一个检查点
_P1_LAST=$(ls -t result/joint/joint-*.pth 2>/dev/null | head -1)
[[ -n "$_P1_LAST" ]] || die "阶段 1 未生成任何检查点result/joint/joint-*.pth"
# 复制为阶段 1 固定终态,供阶段 2 加载
cp "$_P1_LAST" "$P1_FINAL"
log "阶段 1 完成 → $P1_FINAL(来自 $_P1_LAST"
else
[[ -f "$P1_FINAL" ]] || die "跳过阶段 1 但找不到检查点:$P1_FINAL"
log "阶段 1 已跳过(使用现有检查点 $P1_FINAL"
fi
# ------------------------------------------------------------------------------
# 阶段 2完整联合训练
# ------------------------------------------------------------------------------
if [[ $START_PHASE -le 2 ]]; then
log "阶段 2 / 3 完整联合训练 (epochs=${P2_EPOCHS})"
python -m ginka.train_vq \
--train "$TRAIN_DATA" \
--validate "$EVAL_DATA" \
--resume True \
--state "$P1_FINAL" \
--load_optim False \
--freeze_vq False \
--epochs "$P2_EPOCHS" \
--checkpoint "$P2_CHECKPOINT"
log "阶段 2 完成"
fi
# ------------------------------------------------------------------------------
echo ""
echo "╔══════════════════════════════════════════╗"
echo "║ 三阶段训练全部完成 ║"
echo "╚══════════════════════════════════════════╝"