diff --git a/docs/z-improvement-design.md b/docs/z-improvement-design.md index de7c366..2f6fb43 100644 --- a/docs/z-improvement-design.md +++ b/docs/z-improvement-design.md @@ -8,9 +8,10 @@ | 方案 | 核心思路 | 状态 | | ------ | -------------------------------------------------- | -------- | -| 方案 A | 重建一致性约束:将生成结果回送编码器,令 z 闭环 | 待细化 | +| 方案 A | 重建一致性约束:将生成结果回送编码器,令 z 闭环 | 已实施 | | 方案 B | 多路分拆编码:将地图按层次结构分拆为多部分分别编码 | 待细化 | | 方案 C | 多阶段生成:先墙壁,再门怪,最后资源 | 后续计划 | +| 方案 D | VQ 编码器预训练:先单独训练编码器学会重建,再联合 | 待细化 | --- @@ -167,6 +168,80 @@ MaskGIT Cross-Attention(z 作为 memory) --- +--- + +## 方案 D:VQ 编码器预训练 + +### 问题诊断 + +当前联合训练时,VQ 编码器和 MaskGIT 从随机初始化开始同步优化。由于编码器尚未学到任何地图语义,早期 z 基本是随机噪声,MaskGIT 无法从中获得有效的条件信号,两者的优化信号相互干扰,容易导致训练早期陷入局部最优或收敛缓慢。 + +解决思路:在联合训练开始前,先单独预训练 VQ 编码器,使其具备初步的地图语义理解能力,再以此为初始化启动联合训练。 + +### 核心思路 + +为 VQ-VAE 临时增加一个轻量解码头(Decoder Head),构成完整的自编码器,以完整地图重建为目标进行预训练: + +$$\mathcal{L}_{pretrain} = \mathcal{L}_{CE}^{recon} + \beta \cdot \mathcal{L}_{commit} + \gamma \cdot \mathcal{L}_{uniform}$$ + +其中 $\mathcal{L}_{CE}^{recon}$ 是对全部 169 个位置的交叉熵重建损失(不做掩码,全图重建)。预训练完成后,解码头被丢弃,编码器权重作为联合训练的初始化。 + +### 解码头设计 + +解码头的职责是将 z_q [B, L, d_z] 还原为 [B, H*W, num_classes],有以下两种设计选项: + +#### 选项 D-1:Cross-Attention 解码头(推荐) + +``` +z_q [B, L, d_z] + │ + ▼ +可学习位置查询 [B, H*W, d_z](每个格子对应一个 query) + │ Cross-Attention(query=位置查询,key/value=z_q) + ▼ +线性分类头 → logits [B, H*W, num_classes] +``` + +与 MaskGIT 的 Cross-Attention 结构高度一致,预训练阶段即可验证"z → 地图"的解码路径是否畅通。解码头参数量小(单层 Cross-Attention + Linear),预训练速度快。 + +#### 选项 D-2:简单线性展开(基线) + +``` +z_q [B, L, d_z] + │ Flatten → Linear + ▼ +logits [B, H*W, num_classes] +``` + +实现最简单,但 L × d_z → H\*W × num_classes 的映射会引入大量参数(L=32, d_z=128 时约 512K),且缺乏空间归纳偏置,效果可能较差。 + +**推荐选项 D-1**,结构与联合训练阶段的 MaskGIT 解码路径一致,预训练阶段已对"z 作为 Cross-Attention memory 驱动生成"这一机制进行充分热身。 + +### 训练策略 + +| 阶段 | 模型状态 | 目标 | 建议轮数 | +| -------------------- | ----------------------------- | ----------------------------------------- | ------------ | +| 阶段 0:预训练 | 编码器 + 临时解码头 | 全图重建,$\mathcal{L}_{pretrain}$ 收敛 | 20–50 epoch | +| 阶段 1:联合热身 | 编码器冻结 + MaskGIT 训练 | 让 MaskGIT 先适应固定的 z 分布 | 20–40 epoch | +| 阶段 2:完整联合训练 | 全部参数解冻,编码器用较小 LR | 端到端联合优化(可叠加方案 A 一致性约束) | 正常训练轮数 | + +> 阶段 1 的编码器冻结热身建议执行:若直接解冻联合训练,MaskGIT 早期的不稳定梯度可能逐渐覆盖编码器预训练获得的语义。考虑到 MaskGIT 收敛速度相对较慢,热身阶段建议适当延长至 20–40 epoch。 + +### 实现要点 + +1. **解码头独立模块**:将解码头实现为独立的类(如 `VQDecodeHead`),不修改 `GinkaVQVAE` 的核心结构,预训练结束后直接丢弃,不影响联合训练代码路径。 +2. **预训练脚本独立**:新增 `ginka/train_pretrain.py`,与联合训练脚本 `train_vq.py` 分离,便于单独调试。 +3. **权重迁移**:预训练结束后通过 `model_vq.load_state_dict(ckpt['vq_state'], strict=False)` 将编码器权重加载到联合训练中。 +4. **重建质量指标**:预训练阶段重点监控逐类别准确率(尤其是墙壁 tile=1 的召回率),确认编码器已学到基本的空间结构语义。需注意,codebook 容量远小于训练集数量,预训练的目标更倾向于让编码器学会地图的大致分类,而非像素级完整重建——重建损失在此主要作为分类学习的约束信号。 + +### 与其他方案的关系 + +- 方案 D 是**独立于方案 A/B 的训练流程优化**,不修改模型推理时的计算图,与方案 A 的一致性约束、方案 B 的多路编码均可叠加使用。 +- 方案 D 完成后,方案 A 的一致性约束的初始条件更好(编码器已具有语义),收敛应更快、更稳定。 +- 若最终采用方案 B(多路分拆),每个通道的编码器均可独立预训练后再联合训练。 + +--- + ## 两方案的对比 | 维度 | 方案 A(z 闭环) | 方案 B(多路分拆) | @@ -182,6 +257,12 @@ MaskGIT Cross-Attention(z 作为 memory) ## 实施建议 +### 阶段零:预训练编码器(方案 D,可选但推荐) + +1. 实现 `VQDecodeHead`(Cross-Attention 解码头)和独立预训练脚本 `ginka/train_pretrain.py`; +2. 以全图重建为目标预训练 VQ 编码器 20–50 epoch,直至重建准确率(尤其是墙壁类)趋于稳定; +3. 保存编码器权重,作为阶段一联合训练的初始化。 + ### 阶段一:验证方案 A(低风险,快速验证) 1. 在现有联合训练代码中,对子集 A 的训练步骤增加软分布近似一致性损失; @@ -206,6 +287,8 @@ MaskGIT Cross-Attention(z 作为 memory) ## 待细化事项 - [x] 方案 A:一致性损失的权重 $\lambda$ 如何随训练进度调度?→ 先使用常量(初始值 0.1),效果不佳再引入调度策略。 +- [x] 方案 D:预训练阶段是否对所有子集数据都进行预训练,还是只用完整地图?→ 仅使用完整地图(raw_map)。子集划分的差异体现在输入条件上,但输出目标始终是完整地图,预训练阶段无需区分子集。 +- [x] 方案 D:预训练完成后联合训练时,编码器是否需要冻结热身阶段?→ 建议执行冻结热身。若直接解冻联合训练,MaskGIT 的不稳定梯度可能逐渐覆盖编码器预训练所获得的语义;考虑到 MaskGIT 收敛较慢,热身 epoch 数适当增大(建议 20–40 epoch)。 - [x] 方案 A:单步解码还是多步解码后计算一致性损失?→ 训练时 MaskGIT 只进行单步解码,直接在单步结果上计算,无需多步展开。 - [x] 方案 B:通道 2 的"墙壁"是否需要保留,还是只保留入口 + 怪 + 门?→ 保留墙壁。去掉墙壁后剩余内容趋向于散点,缺乏空间结构指导意义。 - [x] 方案 B:三路 z 拼接后总长度是否超出 MaskGIT cross-attention 的合理 memory 长度?→ 先直接拼接,如有性能问题再评估截断或压缩策略。 diff --git a/ginka/train_pretrain.py b/ginka/train_pretrain.py new file mode 100644 index 0000000..a4c9dd4 --- /dev/null +++ b/ginka/train_pretrain.py @@ -0,0 +1,318 @@ +""" +VQ 编码器预训练脚本(方案 D) + +目标:在联合训练开始前,先单独预训练 VQ 编码器,使其学到地图的大致语义分类。 +解码头(VQDecodeHead)仅在预训练阶段使用,结束后丢弃,权重不迁移到联合训练。 + +训练流程(对应设计文档方案 D 三阶段): + 阶段 0(本脚本):编码器 + 临时解码头,全图重建目标 + 阶段 1(在 train_vq.py 中):编码器冻结 + MaskGIT 热身,启用 --freeze_vq + 阶段 2(在 train_vq.py 中):完整联合训练,编码器用较小 LR + +用法示例: + python -m ginka.train_pretrain + python -m ginka.train_pretrain --resume True --state result/pretrain/pretrain-20.pth + # 预训练完成后,传入权重路径启动联合训练阶段 1: + python -m ginka.train_vq --resume True --state result/pretrain/pretrain_final.pth +""" + +import argparse +import os +import sys +from datetime import datetime + +import numpy as np +import torch +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader, Dataset +from tqdm import tqdm + +from .vqvae.model import GinkaVQVAE, VQDecodeHead +from .dataset import load_data + +# --------------------------------------------------------------------------- +# 超参数(须与 train_vq.py 中 VQ-VAE 配置保持一致) +# --------------------------------------------------------------------------- +BATCH_SIZE = 64 +NUM_CLASSES = 16 +MAP_SIZE = 13 * 13 +MAP_H = MAP_W = 13 + +# VQ-VAE 超参(保持与 train_vq.py 一致) +VQ_L = 32 +VQ_K = 1 +VQ_D_Z = 128 +VQ_D_MODEL= 192 +VQ_NHEAD = 8 +VQ_LAYERS = 4 +VQ_DIM_FF = 512 +VQ_BETA = 0.5 +VQ_GAMMA = 0.0 + +# 解码头超参 +DH_NHEAD = 8 # Cross-Attention 头数(VQ_D_Z=128 可被 8 整除) + +# --------------------------------------------------------------------------- +# 设备 +# --------------------------------------------------------------------------- +device = torch.device( + "cuda:1" if torch.cuda.is_available() + else "mps" if torch.backends.mps.is_available() + else "cpu" +) + +os.makedirs("result/pretrain", exist_ok=True) + +disable_tqdm = not sys.stdout.isatty() + +# --------------------------------------------------------------------------- +# 简单数据集:仅返回 raw_map,无子集划分,无掩码 +# --------------------------------------------------------------------------- +class GinkaPretrainDataset(Dataset): + """ + 预训练专用数据集,仅提供完整原始地图(raw_map)和随机数据增强。 + + 不做子集划分与掩码处理;重建目标为全图所有 169 个位置。 + """ + + def __init__(self, data_path: str): + self.data = load_data(data_path) + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + item = self.data[idx] + arr = np.array(item['map'], dtype=np.int64) # [H, W] + + # 随机旋转 / 翻转数据增强 + if np.random.rand() > 0.5: + k = np.random.randint(1, 4) + arr = np.rot90(arr, k).copy() + if np.random.rand() > 0.5: + arr = np.fliplr(arr).copy() + if np.random.rand() > 0.5: + arr = np.flipud(arr).copy() + + raw_map = torch.tensor(arr.reshape(-1), dtype=torch.long) # [H*W] + return raw_map + +# --------------------------------------------------------------------------- +# 参数解析 +# --------------------------------------------------------------------------- +def parse_arguments(): + parser = argparse.ArgumentParser(description="VQ 编码器预训练(方案 D)") + parser.add_argument("--resume", type=bool, default=False) + parser.add_argument("--state", type=str, default="result/pretrain/pretrain-20.pth", + help="续训时加载的检查点路径") + parser.add_argument("--train", type=str, default="ginka-dataset.json") + parser.add_argument("--validate", type=str, default="ginka-eval.json") + parser.add_argument("--epochs", type=int, default=50) + parser.add_argument("--checkpoint", type=int, default=5, + help="每隔多少 epoch 保存检查点并输出验证指标") + parser.add_argument("--load_optim", type=bool, default=True) + return parser.parse_args() + +# --------------------------------------------------------------------------- +# 验证:计算全图 top-1 准确率及关键类别(墙壁)召回率 +# --------------------------------------------------------------------------- +@torch.no_grad() +def validate( + model_vq: GinkaVQVAE, + decode_head: VQDecodeHead, + dataloader_val: DataLoader, +) -> dict: + model_vq.eval() + decode_head.eval() + + total, correct = 0, 0 + wall_tp, wall_gt = 0, 0 # wall tile=1 的召回 + class_correct = torch.zeros(NUM_CLASSES, dtype=torch.long) + class_total = torch.zeros(NUM_CLASSES, dtype=torch.long) + + for raw_map in tqdm(dataloader_val, desc="Validating", leave=False, disable=disable_tqdm): + raw_map = raw_map.to(device) # [B, H*W] + + z_q, _, _, _, _, _ = model_vq(raw_map) + logits = decode_head(z_q) # [B, H*W, C] + pred = logits.argmax(dim=-1) # [B, H*W] + + correct += (pred == raw_map).sum().item() + total += raw_map.numel() + + # 墙壁召回 + wall_mask = (raw_map == 1) + wall_tp += (pred[wall_mask] == 1).sum().item() + wall_gt += wall_mask.sum().item() + + # 逐类别统计 + for c in range(NUM_CLASSES): + mask_c = (raw_map == c) + class_correct[c] += (pred[mask_c] == c).sum().item() + class_total[c] += mask_c.sum().item() + + acc = correct / max(total, 1) + wall_rec = wall_tp / max(wall_gt, 1) + + # 有样本的类别逐一统计 + per_class = {} + for c in range(NUM_CLASSES): + if class_total[c] > 0: + per_class[c] = class_correct[c].item() / class_total[c].item() + + return {"acc": acc, "wall_recall": wall_rec, "per_class": per_class} + +# --------------------------------------------------------------------------- +# 主训练函数 +# --------------------------------------------------------------------------- +def train(): + print(f"Using device: {device}") + args = parse_arguments() + + # ---- 模型 ---- + model_vq = GinkaVQVAE( + num_classes=NUM_CLASSES, + L=VQ_L, K=VQ_K, d_z=VQ_D_Z, + d_model=VQ_D_MODEL, nhead=VQ_NHEAD, + num_layers=VQ_LAYERS, dim_ff=VQ_DIM_FF, + map_size=MAP_SIZE, + beta=VQ_BETA, gamma=VQ_GAMMA, + ).to(device) + + decode_head = VQDecodeHead( + num_classes=NUM_CLASSES, + d_z=VQ_D_Z, + map_size=MAP_SIZE, + nhead=DH_NHEAD, + ).to(device) + + vq_params = sum(p.numel() for p in model_vq.parameters()) + dh_params = sum(p.numel() for p in decode_head.parameters()) + print(f"VQ-VAE 参数量: {vq_params:,} ({vq_params/1e6:.3f}M)") + print(f"DecodeHead 参数量: {dh_params:,} ({dh_params/1e6:.3f}M)") + + # ---- 数据集 ---- + dataset_train = GinkaPretrainDataset(args.train) + dataset_val = GinkaPretrainDataset(args.validate) + dataloader_train = DataLoader( + dataset_train, batch_size=BATCH_SIZE, shuffle=True, + num_workers=0, pin_memory=(device.type == "cuda"), + ) + dataloader_val = DataLoader( + dataset_val, batch_size=BATCH_SIZE, shuffle=False, + num_workers=0, + ) + print(f"训练集: {len(dataset_train)} 条 验证集: {len(dataset_val)} 条") + + # ---- 优化器 ---- + all_params = list(model_vq.parameters()) + list(decode_head.parameters()) + optimizer = optim.AdamW(all_params, lr=2e-4, weight_decay=1e-2) + scheduler = optim.lr_scheduler.CosineAnnealingLR( + optimizer, T_max=args.epochs, eta_min=1e-6 + ) + + # ---- 续训 ---- + start_epoch = 0 + if args.resume: + ckpt = torch.load(args.state, map_location=device) + model_vq.load_state_dict(ckpt["vq_state"], strict=False) + if "dh_state" in ckpt: + decode_head.load_state_dict(ckpt["dh_state"], strict=False) + if args.load_optim and ckpt.get("optim_state") is not None: + optimizer.load_state_dict(ckpt["optim_state"]) + start_epoch = ckpt.get("epoch", 0) + print(f"从 epoch {start_epoch} 接续训练。") + + # ---- 训练循环 ---- + for epoch in tqdm(range(start_epoch, start_epoch + args.epochs), + desc="VQ Pretrain", disable=disable_tqdm): + model_vq.train() + decode_head.train() + + loss_total = 0.0 + ce_total = 0.0 + commit_total = 0.0 + entropy_total = 0.0 + + for raw_map in tqdm(dataloader_train, leave=False, + desc="Epoch Progress", disable=disable_tqdm): + raw_map = raw_map.to(device) # [B, H*W] + + # 1. 编码 + z_q, _, _, vq_loss, commit_loss, entropy_loss = model_vq(raw_map) + + # 2. 解码→全图重建 + logits = decode_head(z_q) # [B, H*W, C] + ce_loss = F.cross_entropy( + logits.permute(0, 2, 1), raw_map # [B, C, H*W] vs [B, H*W] + ) + + # 3. 总损失(重建 + VQ 正则) + loss = ce_loss + vq_loss + + optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(all_params, max_norm=1.0) + optimizer.step() + + loss_total += loss.detach().item() + ce_total += ce_loss.detach().item() + commit_total += commit_loss.detach().item() + entropy_total += entropy_loss.detach().item() + + scheduler.step() + + n = len(dataloader_train) + tqdm.write( + f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " + f"Epoch {epoch + 1:4d} | " + f"Loss {loss_total/n:.5f} " + f"CE {ce_total/n:.5f} " + f"Commit {commit_total/n:.5f} " + f"Entropy {entropy_total/n:.5f} | " + f"LR {scheduler.get_last_lr()[0]:.6f}" + ) + + # ---- 检查点 + 验证 ---- + if (epoch + 1) % args.checkpoint == 0: + ckpt_path = f"result/pretrain/pretrain-{epoch + 1}.pth" + torch.save({ + "epoch": epoch + 1, + "vq_state": model_vq.state_dict(), + "dh_state": decode_head.state_dict(), + "optim_state": optimizer.state_dict(), + }, ckpt_path) + tqdm.write(f" 检查点已保存: {ckpt_path}") + + metrics = validate(model_vq, decode_head, dataloader_val) + acc_str = f" [Validate] Acc {metrics['acc']:.4f} Wall Recall {metrics['wall_recall']:.4f}" + + # 输出有样本的类别准确率 + pc = metrics["per_class"] + detail = " ".join( + f"c{c}={v:.3f}" for c, v in sorted(pc.items()) if v < 1.0 + ) + if detail: + acc_str += f"\n Per-class: {detail}" + tqdm.write(acc_str) + + model_vq.train() + decode_head.train() + + # ---- 保存最终 VQ 编码器权重 ---- + final_path = "result/pretrain/pretrain_final.pth" + torch.save({ + "epoch": start_epoch + args.epochs, + "vq_state": model_vq.state_dict(), + # 不保存解码头:联合训练阶段不需要 + }, final_path) + print(f"\n预训练完成。编码器权重已保存至: {final_path}") + print(f"联合训练阶段 1 启动命令(编码器冻结热身):") + print(f" python -m ginka.train_vq --resume True --state {final_path} --freeze_vq True") + + +# --------------------------------------------------------------------------- +if __name__ == "__main__": + torch.set_num_threads(4) + train() diff --git a/ginka/train_vq.py b/ginka/train_vq.py index fbf42b9..a891582 100644 --- a/ginka/train_vq.py +++ b/ginka/train_vq.py @@ -102,6 +102,9 @@ def parse_arguments(): parser.add_argument("--checkpoint", type=int, default=5, help="每隔多少 epoch 保存检查点并验证") parser.add_argument("--load_optim", type=bool, default=True) + parser.add_argument("--freeze_vq", type=bool, default=False, + help="(方案 D 阶段 1)冻结 VQ 编码器,仅训练 MaskGIT。" + "适用于预训练权重加载后的热身阶段。") return parser.parse_args() # --------------------------------------------------------------------------- @@ -583,6 +586,12 @@ def train(): if img is not None: tile_dict[name] = img + # ---- 方案 D 阶段 1:冻结 VQ 编码器 ---- + if args.freeze_vq: + for p in model_vq.parameters(): + p.requires_grad_(False) + print("VQ 编码器已冻结(方案 D 阶段 1:MaskGIT 热身)。") + # ---- 训练循环 ---- for epoch in tqdm(range(start_epoch, start_epoch + args.epochs), desc="Joint Training", disable=disable_tqdm): diff --git a/ginka/vqvae/model.py b/ginka/vqvae/model.py index add46f7..89cc4c0 100644 --- a/ginka/vqvae/model.py +++ b/ginka/vqvae/model.py @@ -4,6 +4,66 @@ from .quantize import VectorQuantizer from typing import Tuple +class VQDecodeHead(nn.Module): + """ + VQ-VAE 预训练用轻量解码头(Cross-Attention 架构)。 + + 将 z_q [B, L, d_z] 通过 Cross-Attention 还原为地图 logits [B, H*W, num_classes]。 + 预训练结束后此模块被丢弃,不影响联合训练路径。 + + 架构: + 可学习位置查询 [B, H*W, d_z] + → Cross-Attention (query=位置查询, key/value=z_q) + → LayerNorm + → 线性分类头 → logits [B, H*W, num_classes] + """ + + def __init__( + self, + num_classes: int, + d_z: int, + map_size: int, + nhead: int = 4, + ): + """ + Args: + num_classes: tile 类别数 + d_z: z 向量维度(须与 GinkaVQVAE 的 d_z 一致) + map_size: 地图 token 总数(H * W) + nhead: Cross-Attention 的注意力头数(d_z 须能被 nhead 整除) + """ + super().__init__() + + # 每个格子一个可学习位置查询 + self.pos_queries = nn.Parameter(torch.randn(1, map_size, d_z) * 0.02) + + # Cross-Attention:query=位置查询,key/value=z_q + self.cross_attn = nn.MultiheadAttention( + embed_dim=d_z, + num_heads=nhead, + batch_first=True, + dropout=0.0, + ) + self.norm = nn.LayerNorm(d_z) + + # 最终分类头 + self.classifier = nn.Linear(d_z, num_classes) + + def forward(self, z_q: torch.Tensor) -> torch.Tensor: + """ + Args: + z_q: [B, L, d_z] + + Returns: + logits: [B, map_size, num_classes] + """ + B = z_q.shape[0] + q = self.pos_queries.expand(B, -1, -1) # [B, map_size, d_z] + out, _ = self.cross_attn(q, z_q, z_q) # [B, map_size, d_z] + out = self.norm(out) + return self.classifier(out) # [B, map_size, num_classes] + + class GinkaVQVAE(nn.Module): """ VQ-VAE 风格地图编码器。 diff --git a/train_full.sh b/train_full.sh new file mode 100644 index 0000000..4a8409f --- /dev/null +++ b/train_full.sh @@ -0,0 +1,136 @@ +#!/usr/bin/env bash +# ============================================================================== +# 三阶段完整训练流水线 +# +# 阶段 0 VQ 编码器预训练 train_pretrain.py +# 阶段 1 MaskGIT 热身 train_vq.py --freeze_vq True +# 阶段 2 完整联合训练 train_vq.py +# +# 用法: +# bash train_full.sh # 从头开始三阶段训练 +# bash train_full.sh --skip 1 # 跳过阶段 0,从阶段 1 开始 +# bash train_full.sh --skip 2 # 跳过阶段 0-1,直接阶段 2 +# ============================================================================== +set -euo pipefail + +# ------------------------------------------------------------------------------ +# 超参配置(按需修改) +# ------------------------------------------------------------------------------ +TRAIN_DATA="ginka-dataset.json" +EVAL_DATA="ginka-eval.json" + +# 阶段 0:预训练 +P0_EPOCHS=50 +P0_CHECKPOINT=5 +P0_FINAL="result/pretrain/pretrain_final.pth" + +# 阶段 1:冻结编码器热身 +P1_EPOCHS=30 +P1_CHECKPOINT=5 +P1_FINAL="result/joint/warmup_final.pth" + +# 阶段 2:完整联合训练 +P2_EPOCHS=400 +P2_CHECKPOINT=20 + +# 从哪个阶段开始(0 = 从头);命令行 --skip N 可覆盖此值 +START_PHASE=0 + +# ------------------------------------------------------------------------------ +# 解析命令行参数 +# ------------------------------------------------------------------------------ +while [[ $# -gt 0 ]]; do + case "$1" in + --skip) + START_PHASE="$2" + shift 2 + ;; + *) + echo "未知参数: $1"; exit 1 + ;; + esac +done + +# ------------------------------------------------------------------------------ +# 工具函数 +# ------------------------------------------------------------------------------ +log() { + echo "" + echo "════════════════════════════════════════════════════════════════" + echo " $*" + echo " $(date '+%Y-%m-%d %H:%M:%S')" + echo "════════════════════════════════════════════════════════════════" +} + +die() { + echo "[ERROR] $*" >&2 + exit 1 +} + +# ------------------------------------------------------------------------------ +# 阶段 0:VQ 编码器预训练 +# ------------------------------------------------------------------------------ +if [[ $START_PHASE -le 0 ]]; then + log "阶段 0 / 3 VQ 编码器预训练 (epochs=${P0_EPOCHS})" + python -m ginka.train_pretrain \ + --train "$TRAIN_DATA" \ + --validate "$EVAL_DATA" \ + --epochs "$P0_EPOCHS" \ + --checkpoint "$P0_CHECKPOINT" + + [[ -f "$P0_FINAL" ]] || die "阶段 0 未生成预期检查点:$P0_FINAL" + log "阶段 0 完成 → $P0_FINAL" +else + [[ -f "$P0_FINAL" ]] || die "跳过阶段 0 但找不到检查点:$P0_FINAL" + log "阶段 0 已跳过(使用现有检查点 $P0_FINAL)" +fi + +# ------------------------------------------------------------------------------ +# 阶段 1:MaskGIT 热身(VQ 编码器冻结) +# ------------------------------------------------------------------------------ +if [[ $START_PHASE -le 1 ]]; then + log "阶段 1 / 3 MaskGIT 热身(VQ 冻结) (epochs=${P1_EPOCHS})" + python -m ginka.train_vq \ + --train "$TRAIN_DATA" \ + --validate "$EVAL_DATA" \ + --resume True \ + --state "$P0_FINAL" \ + --load_optim False \ + --freeze_vq True \ + --epochs "$P1_EPOCHS" \ + --checkpoint "$P1_CHECKPOINT" + + # 阶段 1 最后一个检查点 + _P1_LAST=$(ls -t result/joint/joint-*.pth 2>/dev/null | head -1) + [[ -n "$_P1_LAST" ]] || die "阶段 1 未生成任何检查点(result/joint/joint-*.pth)" + # 复制为阶段 1 固定终态,供阶段 2 加载 + cp "$_P1_LAST" "$P1_FINAL" + log "阶段 1 完成 → $P1_FINAL(来自 $_P1_LAST)" +else + [[ -f "$P1_FINAL" ]] || die "跳过阶段 1 但找不到检查点:$P1_FINAL" + log "阶段 1 已跳过(使用现有检查点 $P1_FINAL)" +fi + +# ------------------------------------------------------------------------------ +# 阶段 2:完整联合训练 +# ------------------------------------------------------------------------------ +if [[ $START_PHASE -le 2 ]]; then + log "阶段 2 / 3 完整联合训练 (epochs=${P2_EPOCHS})" + python -m ginka.train_vq \ + --train "$TRAIN_DATA" \ + --validate "$EVAL_DATA" \ + --resume True \ + --state "$P1_FINAL" \ + --load_optim False \ + --freeze_vq False \ + --epochs "$P2_EPOCHS" \ + --checkpoint "$P2_CHECKPOINT" + + log "阶段 2 完成" +fi + +# ------------------------------------------------------------------------------ +echo "" +echo "╔══════════════════════════════════════════╗" +echo "║ 三阶段训练全部完成 ║" +echo "╚══════════════════════════════════════════╝"