mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
feat: vq 预训练
Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
parent
789107969b
commit
a69403d6bf
@ -8,9 +8,10 @@
|
||||
|
||||
| 方案 | 核心思路 | 状态 |
|
||||
| ------ | -------------------------------------------------- | -------- |
|
||||
| 方案 A | 重建一致性约束:将生成结果回送编码器,令 z 闭环 | 待细化 |
|
||||
| 方案 A | 重建一致性约束:将生成结果回送编码器,令 z 闭环 | 已实施 |
|
||||
| 方案 B | 多路分拆编码:将地图按层次结构分拆为多部分分别编码 | 待细化 |
|
||||
| 方案 C | 多阶段生成:先墙壁,再门怪,最后资源 | 后续计划 |
|
||||
| 方案 D | VQ 编码器预训练:先单独训练编码器学会重建,再联合 | 待细化 |
|
||||
|
||||
---
|
||||
|
||||
@ -167,6 +168,80 @@ MaskGIT Cross-Attention(z 作为 memory)
|
||||
|
||||
---
|
||||
|
||||
---
|
||||
|
||||
## 方案 D:VQ 编码器预训练
|
||||
|
||||
### 问题诊断
|
||||
|
||||
当前联合训练时,VQ 编码器和 MaskGIT 从随机初始化开始同步优化。由于编码器尚未学到任何地图语义,早期 z 基本是随机噪声,MaskGIT 无法从中获得有效的条件信号,两者的优化信号相互干扰,容易导致训练早期陷入局部最优或收敛缓慢。
|
||||
|
||||
解决思路:在联合训练开始前,先单独预训练 VQ 编码器,使其具备初步的地图语义理解能力,再以此为初始化启动联合训练。
|
||||
|
||||
### 核心思路
|
||||
|
||||
为 VQ-VAE 临时增加一个轻量解码头(Decoder Head),构成完整的自编码器,以完整地图重建为目标进行预训练:
|
||||
|
||||
$$\mathcal{L}_{pretrain} = \mathcal{L}_{CE}^{recon} + \beta \cdot \mathcal{L}_{commit} + \gamma \cdot \mathcal{L}_{uniform}$$
|
||||
|
||||
其中 $\mathcal{L}_{CE}^{recon}$ 是对全部 169 个位置的交叉熵重建损失(不做掩码,全图重建)。预训练完成后,解码头被丢弃,编码器权重作为联合训练的初始化。
|
||||
|
||||
### 解码头设计
|
||||
|
||||
解码头的职责是将 z_q [B, L, d_z] 还原为 [B, H*W, num_classes],有以下两种设计选项:
|
||||
|
||||
#### 选项 D-1:Cross-Attention 解码头(推荐)
|
||||
|
||||
```
|
||||
z_q [B, L, d_z]
|
||||
│
|
||||
▼
|
||||
可学习位置查询 [B, H*W, d_z](每个格子对应一个 query)
|
||||
│ Cross-Attention(query=位置查询,key/value=z_q)
|
||||
▼
|
||||
线性分类头 → logits [B, H*W, num_classes]
|
||||
```
|
||||
|
||||
与 MaskGIT 的 Cross-Attention 结构高度一致,预训练阶段即可验证"z → 地图"的解码路径是否畅通。解码头参数量小(单层 Cross-Attention + Linear),预训练速度快。
|
||||
|
||||
#### 选项 D-2:简单线性展开(基线)
|
||||
|
||||
```
|
||||
z_q [B, L, d_z]
|
||||
│ Flatten → Linear
|
||||
▼
|
||||
logits [B, H*W, num_classes]
|
||||
```
|
||||
|
||||
实现最简单,但 L × d_z → H\*W × num_classes 的映射会引入大量参数(L=32, d_z=128 时约 512K),且缺乏空间归纳偏置,效果可能较差。
|
||||
|
||||
**推荐选项 D-1**,结构与联合训练阶段的 MaskGIT 解码路径一致,预训练阶段已对"z 作为 Cross-Attention memory 驱动生成"这一机制进行充分热身。
|
||||
|
||||
### 训练策略
|
||||
|
||||
| 阶段 | 模型状态 | 目标 | 建议轮数 |
|
||||
| -------------------- | ----------------------------- | ----------------------------------------- | ------------ |
|
||||
| 阶段 0:预训练 | 编码器 + 临时解码头 | 全图重建,$\mathcal{L}_{pretrain}$ 收敛 | 20–50 epoch |
|
||||
| 阶段 1:联合热身 | 编码器冻结 + MaskGIT 训练 | 让 MaskGIT 先适应固定的 z 分布 | 20–40 epoch |
|
||||
| 阶段 2:完整联合训练 | 全部参数解冻,编码器用较小 LR | 端到端联合优化(可叠加方案 A 一致性约束) | 正常训练轮数 |
|
||||
|
||||
> 阶段 1 的编码器冻结热身建议执行:若直接解冻联合训练,MaskGIT 早期的不稳定梯度可能逐渐覆盖编码器预训练获得的语义。考虑到 MaskGIT 收敛速度相对较慢,热身阶段建议适当延长至 20–40 epoch。
|
||||
|
||||
### 实现要点
|
||||
|
||||
1. **解码头独立模块**:将解码头实现为独立的类(如 `VQDecodeHead`),不修改 `GinkaVQVAE` 的核心结构,预训练结束后直接丢弃,不影响联合训练代码路径。
|
||||
2. **预训练脚本独立**:新增 `ginka/train_pretrain.py`,与联合训练脚本 `train_vq.py` 分离,便于单独调试。
|
||||
3. **权重迁移**:预训练结束后通过 `model_vq.load_state_dict(ckpt['vq_state'], strict=False)` 将编码器权重加载到联合训练中。
|
||||
4. **重建质量指标**:预训练阶段重点监控逐类别准确率(尤其是墙壁 tile=1 的召回率),确认编码器已学到基本的空间结构语义。需注意,codebook 容量远小于训练集数量,预训练的目标更倾向于让编码器学会地图的大致分类,而非像素级完整重建——重建损失在此主要作为分类学习的约束信号。
|
||||
|
||||
### 与其他方案的关系
|
||||
|
||||
- 方案 D 是**独立于方案 A/B 的训练流程优化**,不修改模型推理时的计算图,与方案 A 的一致性约束、方案 B 的多路编码均可叠加使用。
|
||||
- 方案 D 完成后,方案 A 的一致性约束的初始条件更好(编码器已具有语义),收敛应更快、更稳定。
|
||||
- 若最终采用方案 B(多路分拆),每个通道的编码器均可独立预训练后再联合训练。
|
||||
|
||||
---
|
||||
|
||||
## 两方案的对比
|
||||
|
||||
| 维度 | 方案 A(z 闭环) | 方案 B(多路分拆) |
|
||||
@ -182,6 +257,12 @@ MaskGIT Cross-Attention(z 作为 memory)
|
||||
|
||||
## 实施建议
|
||||
|
||||
### 阶段零:预训练编码器(方案 D,可选但推荐)
|
||||
|
||||
1. 实现 `VQDecodeHead`(Cross-Attention 解码头)和独立预训练脚本 `ginka/train_pretrain.py`;
|
||||
2. 以全图重建为目标预训练 VQ 编码器 20–50 epoch,直至重建准确率(尤其是墙壁类)趋于稳定;
|
||||
3. 保存编码器权重,作为阶段一联合训练的初始化。
|
||||
|
||||
### 阶段一:验证方案 A(低风险,快速验证)
|
||||
|
||||
1. 在现有联合训练代码中,对子集 A 的训练步骤增加软分布近似一致性损失;
|
||||
@ -206,6 +287,8 @@ MaskGIT Cross-Attention(z 作为 memory)
|
||||
## 待细化事项
|
||||
|
||||
- [x] 方案 A:一致性损失的权重 $\lambda$ 如何随训练进度调度?→ 先使用常量(初始值 0.1),效果不佳再引入调度策略。
|
||||
- [x] 方案 D:预训练阶段是否对所有子集数据都进行预训练,还是只用完整地图?→ 仅使用完整地图(raw_map)。子集划分的差异体现在输入条件上,但输出目标始终是完整地图,预训练阶段无需区分子集。
|
||||
- [x] 方案 D:预训练完成后联合训练时,编码器是否需要冻结热身阶段?→ 建议执行冻结热身。若直接解冻联合训练,MaskGIT 的不稳定梯度可能逐渐覆盖编码器预训练所获得的语义;考虑到 MaskGIT 收敛较慢,热身 epoch 数适当增大(建议 20–40 epoch)。
|
||||
- [x] 方案 A:单步解码还是多步解码后计算一致性损失?→ 训练时 MaskGIT 只进行单步解码,直接在单步结果上计算,无需多步展开。
|
||||
- [x] 方案 B:通道 2 的"墙壁"是否需要保留,还是只保留入口 + 怪 + 门?→ 保留墙壁。去掉墙壁后剩余内容趋向于散点,缺乏空间结构指导意义。
|
||||
- [x] 方案 B:三路 z 拼接后总长度是否超出 MaskGIT cross-attention 的合理 memory 长度?→ 先直接拼接,如有性能问题再评估截断或压缩策略。
|
||||
|
||||
318
ginka/train_pretrain.py
Normal file
318
ginka/train_pretrain.py
Normal file
@ -0,0 +1,318 @@
|
||||
"""
|
||||
VQ 编码器预训练脚本(方案 D)
|
||||
|
||||
目标:在联合训练开始前,先单独预训练 VQ 编码器,使其学到地图的大致语义分类。
|
||||
解码头(VQDecodeHead)仅在预训练阶段使用,结束后丢弃,权重不迁移到联合训练。
|
||||
|
||||
训练流程(对应设计文档方案 D 三阶段):
|
||||
阶段 0(本脚本):编码器 + 临时解码头,全图重建目标
|
||||
阶段 1(在 train_vq.py 中):编码器冻结 + MaskGIT 热身,启用 --freeze_vq
|
||||
阶段 2(在 train_vq.py 中):完整联合训练,编码器用较小 LR
|
||||
|
||||
用法示例:
|
||||
python -m ginka.train_pretrain
|
||||
python -m ginka.train_pretrain --resume True --state result/pretrain/pretrain-20.pth
|
||||
# 预训练完成后,传入权重路径启动联合训练阶段 1:
|
||||
python -m ginka.train_vq --resume True --state result/pretrain/pretrain_final.pth
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
from .vqvae.model import GinkaVQVAE, VQDecodeHead
|
||||
from .dataset import load_data
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 超参数(须与 train_vq.py 中 VQ-VAE 配置保持一致)
|
||||
# ---------------------------------------------------------------------------
|
||||
BATCH_SIZE = 64
|
||||
NUM_CLASSES = 16
|
||||
MAP_SIZE = 13 * 13
|
||||
MAP_H = MAP_W = 13
|
||||
|
||||
# VQ-VAE 超参(保持与 train_vq.py 一致)
|
||||
VQ_L = 32
|
||||
VQ_K = 1
|
||||
VQ_D_Z = 128
|
||||
VQ_D_MODEL= 192
|
||||
VQ_NHEAD = 8
|
||||
VQ_LAYERS = 4
|
||||
VQ_DIM_FF = 512
|
||||
VQ_BETA = 0.5
|
||||
VQ_GAMMA = 0.0
|
||||
|
||||
# 解码头超参
|
||||
DH_NHEAD = 8 # Cross-Attention 头数(VQ_D_Z=128 可被 8 整除)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 设备
|
||||
# ---------------------------------------------------------------------------
|
||||
device = torch.device(
|
||||
"cuda:1" if torch.cuda.is_available()
|
||||
else "mps" if torch.backends.mps.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
|
||||
os.makedirs("result/pretrain", exist_ok=True)
|
||||
|
||||
disable_tqdm = not sys.stdout.isatty()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 简单数据集:仅返回 raw_map,无子集划分,无掩码
|
||||
# ---------------------------------------------------------------------------
|
||||
class GinkaPretrainDataset(Dataset):
|
||||
"""
|
||||
预训练专用数据集,仅提供完整原始地图(raw_map)和随机数据增强。
|
||||
|
||||
不做子集划分与掩码处理;重建目标为全图所有 169 个位置。
|
||||
"""
|
||||
|
||||
def __init__(self, data_path: str):
|
||||
self.data = load_data(data_path)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = self.data[idx]
|
||||
arr = np.array(item['map'], dtype=np.int64) # [H, W]
|
||||
|
||||
# 随机旋转 / 翻转数据增强
|
||||
if np.random.rand() > 0.5:
|
||||
k = np.random.randint(1, 4)
|
||||
arr = np.rot90(arr, k).copy()
|
||||
if np.random.rand() > 0.5:
|
||||
arr = np.fliplr(arr).copy()
|
||||
if np.random.rand() > 0.5:
|
||||
arr = np.flipud(arr).copy()
|
||||
|
||||
raw_map = torch.tensor(arr.reshape(-1), dtype=torch.long) # [H*W]
|
||||
return raw_map
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 参数解析
|
||||
# ---------------------------------------------------------------------------
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(description="VQ 编码器预训练(方案 D)")
|
||||
parser.add_argument("--resume", type=bool, default=False)
|
||||
parser.add_argument("--state", type=str, default="result/pretrain/pretrain-20.pth",
|
||||
help="续训时加载的检查点路径")
|
||||
parser.add_argument("--train", type=str, default="ginka-dataset.json")
|
||||
parser.add_argument("--validate", type=str, default="ginka-eval.json")
|
||||
parser.add_argument("--epochs", type=int, default=50)
|
||||
parser.add_argument("--checkpoint", type=int, default=5,
|
||||
help="每隔多少 epoch 保存检查点并输出验证指标")
|
||||
parser.add_argument("--load_optim", type=bool, default=True)
|
||||
return parser.parse_args()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 验证:计算全图 top-1 准确率及关键类别(墙壁)召回率
|
||||
# ---------------------------------------------------------------------------
|
||||
@torch.no_grad()
|
||||
def validate(
|
||||
model_vq: GinkaVQVAE,
|
||||
decode_head: VQDecodeHead,
|
||||
dataloader_val: DataLoader,
|
||||
) -> dict:
|
||||
model_vq.eval()
|
||||
decode_head.eval()
|
||||
|
||||
total, correct = 0, 0
|
||||
wall_tp, wall_gt = 0, 0 # wall tile=1 的召回
|
||||
class_correct = torch.zeros(NUM_CLASSES, dtype=torch.long)
|
||||
class_total = torch.zeros(NUM_CLASSES, dtype=torch.long)
|
||||
|
||||
for raw_map in tqdm(dataloader_val, desc="Validating", leave=False, disable=disable_tqdm):
|
||||
raw_map = raw_map.to(device) # [B, H*W]
|
||||
|
||||
z_q, _, _, _, _, _ = model_vq(raw_map)
|
||||
logits = decode_head(z_q) # [B, H*W, C]
|
||||
pred = logits.argmax(dim=-1) # [B, H*W]
|
||||
|
||||
correct += (pred == raw_map).sum().item()
|
||||
total += raw_map.numel()
|
||||
|
||||
# 墙壁召回
|
||||
wall_mask = (raw_map == 1)
|
||||
wall_tp += (pred[wall_mask] == 1).sum().item()
|
||||
wall_gt += wall_mask.sum().item()
|
||||
|
||||
# 逐类别统计
|
||||
for c in range(NUM_CLASSES):
|
||||
mask_c = (raw_map == c)
|
||||
class_correct[c] += (pred[mask_c] == c).sum().item()
|
||||
class_total[c] += mask_c.sum().item()
|
||||
|
||||
acc = correct / max(total, 1)
|
||||
wall_rec = wall_tp / max(wall_gt, 1)
|
||||
|
||||
# 有样本的类别逐一统计
|
||||
per_class = {}
|
||||
for c in range(NUM_CLASSES):
|
||||
if class_total[c] > 0:
|
||||
per_class[c] = class_correct[c].item() / class_total[c].item()
|
||||
|
||||
return {"acc": acc, "wall_recall": wall_rec, "per_class": per_class}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 主训练函数
|
||||
# ---------------------------------------------------------------------------
|
||||
def train():
|
||||
print(f"Using device: {device}")
|
||||
args = parse_arguments()
|
||||
|
||||
# ---- 模型 ----
|
||||
model_vq = GinkaVQVAE(
|
||||
num_classes=NUM_CLASSES,
|
||||
L=VQ_L, K=VQ_K, d_z=VQ_D_Z,
|
||||
d_model=VQ_D_MODEL, nhead=VQ_NHEAD,
|
||||
num_layers=VQ_LAYERS, dim_ff=VQ_DIM_FF,
|
||||
map_size=MAP_SIZE,
|
||||
beta=VQ_BETA, gamma=VQ_GAMMA,
|
||||
).to(device)
|
||||
|
||||
decode_head = VQDecodeHead(
|
||||
num_classes=NUM_CLASSES,
|
||||
d_z=VQ_D_Z,
|
||||
map_size=MAP_SIZE,
|
||||
nhead=DH_NHEAD,
|
||||
).to(device)
|
||||
|
||||
vq_params = sum(p.numel() for p in model_vq.parameters())
|
||||
dh_params = sum(p.numel() for p in decode_head.parameters())
|
||||
print(f"VQ-VAE 参数量: {vq_params:,} ({vq_params/1e6:.3f}M)")
|
||||
print(f"DecodeHead 参数量: {dh_params:,} ({dh_params/1e6:.3f}M)")
|
||||
|
||||
# ---- 数据集 ----
|
||||
dataset_train = GinkaPretrainDataset(args.train)
|
||||
dataset_val = GinkaPretrainDataset(args.validate)
|
||||
dataloader_train = DataLoader(
|
||||
dataset_train, batch_size=BATCH_SIZE, shuffle=True,
|
||||
num_workers=0, pin_memory=(device.type == "cuda"),
|
||||
)
|
||||
dataloader_val = DataLoader(
|
||||
dataset_val, batch_size=BATCH_SIZE, shuffle=False,
|
||||
num_workers=0,
|
||||
)
|
||||
print(f"训练集: {len(dataset_train)} 条 验证集: {len(dataset_val)} 条")
|
||||
|
||||
# ---- 优化器 ----
|
||||
all_params = list(model_vq.parameters()) + list(decode_head.parameters())
|
||||
optimizer = optim.AdamW(all_params, lr=2e-4, weight_decay=1e-2)
|
||||
scheduler = optim.lr_scheduler.CosineAnnealingLR(
|
||||
optimizer, T_max=args.epochs, eta_min=1e-6
|
||||
)
|
||||
|
||||
# ---- 续训 ----
|
||||
start_epoch = 0
|
||||
if args.resume:
|
||||
ckpt = torch.load(args.state, map_location=device)
|
||||
model_vq.load_state_dict(ckpt["vq_state"], strict=False)
|
||||
if "dh_state" in ckpt:
|
||||
decode_head.load_state_dict(ckpt["dh_state"], strict=False)
|
||||
if args.load_optim and ckpt.get("optim_state") is not None:
|
||||
optimizer.load_state_dict(ckpt["optim_state"])
|
||||
start_epoch = ckpt.get("epoch", 0)
|
||||
print(f"从 epoch {start_epoch} 接续训练。")
|
||||
|
||||
# ---- 训练循环 ----
|
||||
for epoch in tqdm(range(start_epoch, start_epoch + args.epochs),
|
||||
desc="VQ Pretrain", disable=disable_tqdm):
|
||||
model_vq.train()
|
||||
decode_head.train()
|
||||
|
||||
loss_total = 0.0
|
||||
ce_total = 0.0
|
||||
commit_total = 0.0
|
||||
entropy_total = 0.0
|
||||
|
||||
for raw_map in tqdm(dataloader_train, leave=False,
|
||||
desc="Epoch Progress", disable=disable_tqdm):
|
||||
raw_map = raw_map.to(device) # [B, H*W]
|
||||
|
||||
# 1. 编码
|
||||
z_q, _, _, vq_loss, commit_loss, entropy_loss = model_vq(raw_map)
|
||||
|
||||
# 2. 解码→全图重建
|
||||
logits = decode_head(z_q) # [B, H*W, C]
|
||||
ce_loss = F.cross_entropy(
|
||||
logits.permute(0, 2, 1), raw_map # [B, C, H*W] vs [B, H*W]
|
||||
)
|
||||
|
||||
# 3. 总损失(重建 + VQ 正则)
|
||||
loss = ce_loss + vq_loss
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(all_params, max_norm=1.0)
|
||||
optimizer.step()
|
||||
|
||||
loss_total += loss.detach().item()
|
||||
ce_total += ce_loss.detach().item()
|
||||
commit_total += commit_loss.detach().item()
|
||||
entropy_total += entropy_loss.detach().item()
|
||||
|
||||
scheduler.step()
|
||||
|
||||
n = len(dataloader_train)
|
||||
tqdm.write(
|
||||
f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
|
||||
f"Epoch {epoch + 1:4d} | "
|
||||
f"Loss {loss_total/n:.5f} "
|
||||
f"CE {ce_total/n:.5f} "
|
||||
f"Commit {commit_total/n:.5f} "
|
||||
f"Entropy {entropy_total/n:.5f} | "
|
||||
f"LR {scheduler.get_last_lr()[0]:.6f}"
|
||||
)
|
||||
|
||||
# ---- 检查点 + 验证 ----
|
||||
if (epoch + 1) % args.checkpoint == 0:
|
||||
ckpt_path = f"result/pretrain/pretrain-{epoch + 1}.pth"
|
||||
torch.save({
|
||||
"epoch": epoch + 1,
|
||||
"vq_state": model_vq.state_dict(),
|
||||
"dh_state": decode_head.state_dict(),
|
||||
"optim_state": optimizer.state_dict(),
|
||||
}, ckpt_path)
|
||||
tqdm.write(f" 检查点已保存: {ckpt_path}")
|
||||
|
||||
metrics = validate(model_vq, decode_head, dataloader_val)
|
||||
acc_str = f" [Validate] Acc {metrics['acc']:.4f} Wall Recall {metrics['wall_recall']:.4f}"
|
||||
|
||||
# 输出有样本的类别准确率
|
||||
pc = metrics["per_class"]
|
||||
detail = " ".join(
|
||||
f"c{c}={v:.3f}" for c, v in sorted(pc.items()) if v < 1.0
|
||||
)
|
||||
if detail:
|
||||
acc_str += f"\n Per-class: {detail}"
|
||||
tqdm.write(acc_str)
|
||||
|
||||
model_vq.train()
|
||||
decode_head.train()
|
||||
|
||||
# ---- 保存最终 VQ 编码器权重 ----
|
||||
final_path = "result/pretrain/pretrain_final.pth"
|
||||
torch.save({
|
||||
"epoch": start_epoch + args.epochs,
|
||||
"vq_state": model_vq.state_dict(),
|
||||
# 不保存解码头:联合训练阶段不需要
|
||||
}, final_path)
|
||||
print(f"\n预训练完成。编码器权重已保存至: {final_path}")
|
||||
print(f"联合训练阶段 1 启动命令(编码器冻结热身):")
|
||||
print(f" python -m ginka.train_vq --resume True --state {final_path} --freeze_vq True")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
if __name__ == "__main__":
|
||||
torch.set_num_threads(4)
|
||||
train()
|
||||
@ -102,6 +102,9 @@ def parse_arguments():
|
||||
parser.add_argument("--checkpoint", type=int, default=5,
|
||||
help="每隔多少 epoch 保存检查点并验证")
|
||||
parser.add_argument("--load_optim", type=bool, default=True)
|
||||
parser.add_argument("--freeze_vq", type=bool, default=False,
|
||||
help="(方案 D 阶段 1)冻结 VQ 编码器,仅训练 MaskGIT。"
|
||||
"适用于预训练权重加载后的热身阶段。")
|
||||
return parser.parse_args()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -583,6 +586,12 @@ def train():
|
||||
if img is not None:
|
||||
tile_dict[name] = img
|
||||
|
||||
# ---- 方案 D 阶段 1:冻结 VQ 编码器 ----
|
||||
if args.freeze_vq:
|
||||
for p in model_vq.parameters():
|
||||
p.requires_grad_(False)
|
||||
print("VQ 编码器已冻结(方案 D 阶段 1:MaskGIT 热身)。")
|
||||
|
||||
# ---- 训练循环 ----
|
||||
for epoch in tqdm(range(start_epoch, start_epoch + args.epochs),
|
||||
desc="Joint Training", disable=disable_tqdm):
|
||||
|
||||
@ -4,6 +4,66 @@ from .quantize import VectorQuantizer
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
class VQDecodeHead(nn.Module):
|
||||
"""
|
||||
VQ-VAE 预训练用轻量解码头(Cross-Attention 架构)。
|
||||
|
||||
将 z_q [B, L, d_z] 通过 Cross-Attention 还原为地图 logits [B, H*W, num_classes]。
|
||||
预训练结束后此模块被丢弃,不影响联合训练路径。
|
||||
|
||||
架构:
|
||||
可学习位置查询 [B, H*W, d_z]
|
||||
→ Cross-Attention (query=位置查询, key/value=z_q)
|
||||
→ LayerNorm
|
||||
→ 线性分类头 → logits [B, H*W, num_classes]
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_classes: int,
|
||||
d_z: int,
|
||||
map_size: int,
|
||||
nhead: int = 4,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
num_classes: tile 类别数
|
||||
d_z: z 向量维度(须与 GinkaVQVAE 的 d_z 一致)
|
||||
map_size: 地图 token 总数(H * W)
|
||||
nhead: Cross-Attention 的注意力头数(d_z 须能被 nhead 整除)
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
# 每个格子一个可学习位置查询
|
||||
self.pos_queries = nn.Parameter(torch.randn(1, map_size, d_z) * 0.02)
|
||||
|
||||
# Cross-Attention:query=位置查询,key/value=z_q
|
||||
self.cross_attn = nn.MultiheadAttention(
|
||||
embed_dim=d_z,
|
||||
num_heads=nhead,
|
||||
batch_first=True,
|
||||
dropout=0.0,
|
||||
)
|
||||
self.norm = nn.LayerNorm(d_z)
|
||||
|
||||
# 最终分类头
|
||||
self.classifier = nn.Linear(d_z, num_classes)
|
||||
|
||||
def forward(self, z_q: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
z_q: [B, L, d_z]
|
||||
|
||||
Returns:
|
||||
logits: [B, map_size, num_classes]
|
||||
"""
|
||||
B = z_q.shape[0]
|
||||
q = self.pos_queries.expand(B, -1, -1) # [B, map_size, d_z]
|
||||
out, _ = self.cross_attn(q, z_q, z_q) # [B, map_size, d_z]
|
||||
out = self.norm(out)
|
||||
return self.classifier(out) # [B, map_size, num_classes]
|
||||
|
||||
|
||||
class GinkaVQVAE(nn.Module):
|
||||
"""
|
||||
VQ-VAE 风格地图编码器。
|
||||
|
||||
136
train_full.sh
Normal file
136
train_full.sh
Normal file
@ -0,0 +1,136 @@
|
||||
#!/usr/bin/env bash
|
||||
# ==============================================================================
|
||||
# 三阶段完整训练流水线
|
||||
#
|
||||
# 阶段 0 VQ 编码器预训练 train_pretrain.py
|
||||
# 阶段 1 MaskGIT 热身 train_vq.py --freeze_vq True
|
||||
# 阶段 2 完整联合训练 train_vq.py
|
||||
#
|
||||
# 用法:
|
||||
# bash train_full.sh # 从头开始三阶段训练
|
||||
# bash train_full.sh --skip 1 # 跳过阶段 0,从阶段 1 开始
|
||||
# bash train_full.sh --skip 2 # 跳过阶段 0-1,直接阶段 2
|
||||
# ==============================================================================
|
||||
set -euo pipefail
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# 超参配置(按需修改)
|
||||
# ------------------------------------------------------------------------------
|
||||
TRAIN_DATA="ginka-dataset.json"
|
||||
EVAL_DATA="ginka-eval.json"
|
||||
|
||||
# 阶段 0:预训练
|
||||
P0_EPOCHS=50
|
||||
P0_CHECKPOINT=5
|
||||
P0_FINAL="result/pretrain/pretrain_final.pth"
|
||||
|
||||
# 阶段 1:冻结编码器热身
|
||||
P1_EPOCHS=30
|
||||
P1_CHECKPOINT=5
|
||||
P1_FINAL="result/joint/warmup_final.pth"
|
||||
|
||||
# 阶段 2:完整联合训练
|
||||
P2_EPOCHS=400
|
||||
P2_CHECKPOINT=20
|
||||
|
||||
# 从哪个阶段开始(0 = 从头);命令行 --skip N 可覆盖此值
|
||||
START_PHASE=0
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# 解析命令行参数
|
||||
# ------------------------------------------------------------------------------
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case "$1" in
|
||||
--skip)
|
||||
START_PHASE="$2"
|
||||
shift 2
|
||||
;;
|
||||
*)
|
||||
echo "未知参数: $1"; exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# 工具函数
|
||||
# ------------------------------------------------------------------------------
|
||||
log() {
|
||||
echo ""
|
||||
echo "════════════════════════════════════════════════════════════════"
|
||||
echo " $*"
|
||||
echo " $(date '+%Y-%m-%d %H:%M:%S')"
|
||||
echo "════════════════════════════════════════════════════════════════"
|
||||
}
|
||||
|
||||
die() {
|
||||
echo "[ERROR] $*" >&2
|
||||
exit 1
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# 阶段 0:VQ 编码器预训练
|
||||
# ------------------------------------------------------------------------------
|
||||
if [[ $START_PHASE -le 0 ]]; then
|
||||
log "阶段 0 / 3 VQ 编码器预训练 (epochs=${P0_EPOCHS})"
|
||||
python -m ginka.train_pretrain \
|
||||
--train "$TRAIN_DATA" \
|
||||
--validate "$EVAL_DATA" \
|
||||
--epochs "$P0_EPOCHS" \
|
||||
--checkpoint "$P0_CHECKPOINT"
|
||||
|
||||
[[ -f "$P0_FINAL" ]] || die "阶段 0 未生成预期检查点:$P0_FINAL"
|
||||
log "阶段 0 完成 → $P0_FINAL"
|
||||
else
|
||||
[[ -f "$P0_FINAL" ]] || die "跳过阶段 0 但找不到检查点:$P0_FINAL"
|
||||
log "阶段 0 已跳过(使用现有检查点 $P0_FINAL)"
|
||||
fi
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# 阶段 1:MaskGIT 热身(VQ 编码器冻结)
|
||||
# ------------------------------------------------------------------------------
|
||||
if [[ $START_PHASE -le 1 ]]; then
|
||||
log "阶段 1 / 3 MaskGIT 热身(VQ 冻结) (epochs=${P1_EPOCHS})"
|
||||
python -m ginka.train_vq \
|
||||
--train "$TRAIN_DATA" \
|
||||
--validate "$EVAL_DATA" \
|
||||
--resume True \
|
||||
--state "$P0_FINAL" \
|
||||
--load_optim False \
|
||||
--freeze_vq True \
|
||||
--epochs "$P1_EPOCHS" \
|
||||
--checkpoint "$P1_CHECKPOINT"
|
||||
|
||||
# 阶段 1 最后一个检查点
|
||||
_P1_LAST=$(ls -t result/joint/joint-*.pth 2>/dev/null | head -1)
|
||||
[[ -n "$_P1_LAST" ]] || die "阶段 1 未生成任何检查点(result/joint/joint-*.pth)"
|
||||
# 复制为阶段 1 固定终态,供阶段 2 加载
|
||||
cp "$_P1_LAST" "$P1_FINAL"
|
||||
log "阶段 1 完成 → $P1_FINAL(来自 $_P1_LAST)"
|
||||
else
|
||||
[[ -f "$P1_FINAL" ]] || die "跳过阶段 1 但找不到检查点:$P1_FINAL"
|
||||
log "阶段 1 已跳过(使用现有检查点 $P1_FINAL)"
|
||||
fi
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# 阶段 2:完整联合训练
|
||||
# ------------------------------------------------------------------------------
|
||||
if [[ $START_PHASE -le 2 ]]; then
|
||||
log "阶段 2 / 3 完整联合训练 (epochs=${P2_EPOCHS})"
|
||||
python -m ginka.train_vq \
|
||||
--train "$TRAIN_DATA" \
|
||||
--validate "$EVAL_DATA" \
|
||||
--resume True \
|
||||
--state "$P1_FINAL" \
|
||||
--load_optim False \
|
||||
--freeze_vq False \
|
||||
--epochs "$P2_EPOCHS" \
|
||||
--checkpoint "$P2_CHECKPOINT"
|
||||
|
||||
log "阶段 2 完成"
|
||||
fi
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
echo ""
|
||||
echo "╔══════════════════════════════════════════╗"
|
||||
echo "║ 三阶段训练全部完成 ║"
|
||||
echo "╚══════════════════════════════════════════╝"
|
||||
Loading…
Reference in New Issue
Block a user