ginka-generator/ginka/train_pretrain.py

"""
VQ 编码器预训练脚本(方案 D
目标:在联合训练开始前,先单独预训练 VQ 编码器,使其学到地图的大致语义分类。
解码头VQDecodeHead仅在预训练阶段使用结束后丢弃权重不迁移到联合训练。
训练流程(对应设计文档方案 D 三阶段):
阶段 0本脚本编码器 + 临时解码头,全图重建目标
阶段 1在 train_vq.py 中):编码器冻结 + MaskGIT 热身,启用 --freeze_vq
阶段 2在 train_vq.py 中):完整联合训练,编码器用较小 LR
用法示例:
python -m ginka.train_pretrain
python -m ginka.train_pretrain --resume True --state result/pretrain/pretrain-20.pth
# 预训练完成后,传入权重路径启动联合训练阶段 1
python -m ginka.train_vq --resume True --state result/pretrain/pretrain_final.pth
"""
import argparse
import os
import sys
from datetime import datetime
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from .vqvae.model import GinkaVQVAE, VQDecodeHead
from .dataset import load_data
# ---------------------------------------------------------------------------
# Hyperparameters (must stay consistent with the VQ-VAE config in train_vq.py)
# ---------------------------------------------------------------------------
BATCH_SIZE = 64
NUM_CLASSES = 16
MAP_SIZE = 13 * 13
MAP_H = MAP_W = 13
# VQ-VAE hyperparameters (keep consistent with train_vq.py)
VQ_L = 2
VQ_K = 8
VQ_D_Z = 128
VQ_D_MODEL = 192
VQ_NHEAD = 8
VQ_LAYERS = 4
VQ_DIM_FF = 512
VQ_BETA = 0.5
VQ_GAMMA = 0.0
# Focal loss
FOCAL_GAMMA = 2.0  # focal loss focusing parameter (larger = more weight on hard / rare classes)
# Decode-head hyperparameters (symmetric with the encoder: same depth and FFN width)
DH_NHEAD = 8     # number of cross-attention heads (VQ_D_Z=128 is divisible by 8)
DH_DIM_FF = 512  # FFN hidden dimension (matches the encoder's VQ_DIM_FF)
DH_LAYERS = 4    # number of decoder layers (matches the encoder's VQ_LAYERS)
# ---------------------------------------------------------------------------
# Device
# ---------------------------------------------------------------------------
device = torch.device(
    "cuda:1" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
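# NOTE: "cuda:1" pins the script to the second GPU and assumes a multi-GPU
# machine; on a single-GPU box, change it to "cuda:0" (or simply "cuda").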
os.makedirs("result/pretrain", exist_ok=True)
disable_tqdm = not sys.stdout.isatty()
# ---------------------------------------------------------------------------
# Focal Loss
# ---------------------------------------------------------------------------
def focal_loss(
    logits: torch.Tensor,
    targets: torch.Tensor,
    gamma: float = FOCAL_GAMMA,
) -> torch.Tensor:
    """
    Multi-class focal loss (mean reduction): FL = -(1 - p_t)^gamma * log(p_t).
    Compared with plain CE, it down-weights samples that are already classified
    with high confidence, forcing the model to focus on hard-to-classify, rare
    tiles (doors / monsters / resources, etc.).
    """
    ce = F.cross_entropy(logits, targets, reduction='none')
    pt = torch.exp(-ce)
    return ((1.0 - pt) ** gamma * ce).mean()
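# A quick sanity check of the focusing behaviour (hypothetical numbers): with
# gamma = 2, a confidently correct tile (p_t = 0.9) keeps only (1 - 0.9)^2 = 1%
# of its CE term, while a hard tile (p_t = 0.1) keeps (1 - 0.1)^2 = 81%, so the
# gradient budget shifts toward the rare tile classes.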
# ---------------------------------------------------------------------------
# Simple dataset: returns only raw_map (no subset splitting, no masking)
# ---------------------------------------------------------------------------
class GinkaPretrainDataset(Dataset):
    """
    Pre-training-only dataset: provides the complete raw map (raw_map) plus
    random data augmentation. No subset splitting or masking is applied; the
    reconstruction target is all 169 positions of the full map.
    """
    def __init__(self, data_path: str):
        self.data = load_data(data_path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        arr = np.array(item['map'], dtype=np.int64)  # [H, W]
        # Random rotation / flip augmentation
        if np.random.rand() > 0.5:
            k = np.random.randint(1, 4)
            arr = np.rot90(arr, k).copy()
        if np.random.rand() > 0.5:
            arr = np.fliplr(arr).copy()
        if np.random.rand() > 0.5:
            arr = np.flipud(arr).copy()
        raw_map = torch.tensor(arr.reshape(-1), dtype=torch.long)  # [H*W]
        return raw_map
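# Note: the independent rot90 / fliplr / flipud draws above can together reach
# all eight symmetries of the square (the dihedral group D4), though with
# non-uniform probabilities, and every augmented map stays a valid 13x13 layout.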
# ---------------------------------------------------------------------------
# Argument parsing
# ---------------------------------------------------------------------------
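# argparse's `type=bool` treats any non-empty string as True (so "--resume False"
# would still enable resuming). This small helper, added here as a fix, keeps the
# documented "--resume True/False" style working as intended.
def str2bool(v: str) -> bool:
    """Parse common true/false spellings from the command line."""
    return str(v).strip().lower() in ("true", "1", "yes", "y")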
def parse_arguments():
    parser = argparse.ArgumentParser(description="VQ encoder pre-training (Scheme D)")
    parser.add_argument("--resume", type=str2bool, default=False)
    parser.add_argument("--state", type=str, default="result/pretrain/pretrain-20.pth",
                        help="checkpoint path to load when resuming")
    parser.add_argument("--train", type=str, default="ginka-dataset.json")
    parser.add_argument("--validate", type=str, default="ginka-eval.json")
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--checkpoint", type=int, default=5,
                        help="save a checkpoint and report validation metrics every N epochs")
    parser.add_argument("--load_optim", type=str2bool, default=True)
    return parser.parse_args()
# ---------------------------------------------------------------------------
# Validation: full-map top-1 accuracy plus recall of the key class (walls)
# ---------------------------------------------------------------------------
@torch.no_grad()
def validate(
    model_vq: GinkaVQVAE,
    decode_head: VQDecodeHead,
    dataloader_val: DataLoader,
) -> dict:
    model_vq.eval()
    decode_head.eval()
    total, correct = 0, 0
    wall_tp, wall_gt = 0, 0  # recall of wall (tile=1)
    class_correct = torch.zeros(NUM_CLASSES, dtype=torch.long)
    class_total = torch.zeros(NUM_CLASSES, dtype=torch.long)
    for raw_map in tqdm(dataloader_val, desc="Validating", leave=False, disable=disable_tqdm):
        raw_map = raw_map.to(device)      # [B, H*W]
        z_q, _, _, _, _, _ = model_vq(raw_map)
        logits = decode_head(z_q)         # [B, H*W, C]
        pred = logits.argmax(dim=-1)      # [B, H*W]
        correct += (pred == raw_map).sum().item()
        total += raw_map.numel()
        # Wall recall
        wall_mask = (raw_map == 1)
        wall_tp += (pred[wall_mask] == 1).sum().item()
        wall_gt += wall_mask.sum().item()
        # Per-class statistics
        for c in range(NUM_CLASSES):
            mask_c = (raw_map == c)
            class_correct[c] += (pred[mask_c] == c).sum().item()
            class_total[c] += mask_c.sum().item()
    acc = correct / max(total, 1)
    wall_rec = wall_tp / max(wall_gt, 1)
    # Per-class accuracy, for classes that actually appear in the data
    per_class = {}
    for c in range(NUM_CLASSES):
        if class_total[c] > 0:
            per_class[c] = class_correct[c].item() / class_total[c].item()
    return {"acc": acc, "wall_recall": wall_rec, "per_class": per_class}
# ---------------------------------------------------------------------------
# Main training function
# ---------------------------------------------------------------------------
def train():
    print(f"Using device: {device}")
    args = parse_arguments()
    # ---- Models ----
    model_vq = GinkaVQVAE(
        num_classes=NUM_CLASSES,
        L=VQ_L, K=VQ_K, d_z=VQ_D_Z,
        d_model=VQ_D_MODEL, nhead=VQ_NHEAD,
        num_layers=VQ_LAYERS, dim_ff=VQ_DIM_FF,
        map_size=MAP_SIZE,
        beta=VQ_BETA, gamma=VQ_GAMMA,
    ).to(device)
    decode_head = VQDecodeHead(
        num_classes=NUM_CLASSES,
        d_z=VQ_D_Z,
        map_size=MAP_SIZE,
        nhead=DH_NHEAD,
        dim_ff=DH_DIM_FF,
        num_layers=DH_LAYERS,
    ).to(device)
    vq_params = sum(p.numel() for p in model_vq.parameters())
    dh_params = sum(p.numel() for p in decode_head.parameters())
    print(f"VQ-VAE parameter count: {vq_params:,} ({vq_params/1e6:.3f}M)")
    print(f"DecodeHead parameter count: {dh_params:,} ({dh_params/1e6:.3f}M)")
    # ---- Datasets ----
    dataset_train = GinkaPretrainDataset(args.train)
    dataset_val = GinkaPretrainDataset(args.validate)
    dataloader_train = DataLoader(
        dataset_train, batch_size=BATCH_SIZE, shuffle=True,
        num_workers=0, pin_memory=(device.type == "cuda"),
    )
    dataloader_val = DataLoader(
        dataset_val, batch_size=BATCH_SIZE, shuffle=False,
        num_workers=0,
    )
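    # Note: if num_workers is ever raised above 0, seed the NumPy RNG per worker
    # (e.g. via worker_init_fn); forked workers otherwise inherit identical RNG
    # state and would apply the same "random" augmentations.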
print(f"训练集: {len(dataset_train)} 条 验证集: {len(dataset_val)}")
# ---- 优化器 ----
all_params = list(model_vq.parameters()) + list(decode_head.parameters())
optimizer = optim.AdamW(all_params, lr=2e-4, weight_decay=1e-2)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=args.epochs, eta_min=1e-6
)
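    # Cosine annealing decays the LR from 2e-4 toward eta_min over args.epochs:
    #   lr(t) = eta_min + (lr0 - eta_min) * (1 + cos(pi * t / T_max)) / 2
    # Note: the scheduler state is not restored on resume, so the schedule
    # restarts from lr0 and spans another args.epochs from the restart point.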
    # ---- Resume ----
    start_epoch = 0
    if args.resume:
        ckpt = torch.load(args.state, map_location=device)
        model_vq.load_state_dict(ckpt["vq_state"], strict=False)
        if "dh_state" in ckpt:
            decode_head.load_state_dict(ckpt["dh_state"], strict=False)
        if args.load_optim and ckpt.get("optim_state") is not None:
            optimizer.load_state_dict(ckpt["optim_state"])
        start_epoch = ckpt.get("epoch", 0)
        print(f"Resuming training from epoch {start_epoch}.")
    # ---- Training loop ----
    for epoch in tqdm(range(start_epoch, start_epoch + args.epochs),
                      desc="VQ Pretrain", disable=disable_tqdm):
        model_vq.train()
        decode_head.train()
        loss_total = 0.0
        ce_total = 0.0
        commit_total = 0.0
        entropy_total = 0.0
        for raw_map in tqdm(dataloader_train, leave=False,
                            desc="Epoch Progress", disable=disable_tqdm):
            raw_map = raw_map.to(device)  # [B, H*W]
            # 1. Encode
            z_q, _, _, vq_loss, commit_loss, entropy_loss = model_vq(raw_map)
            # 2. Decode -> full-map reconstruction (focal loss mitigates wall/floor dominance)
            logits = decode_head(z_q)  # [B, H*W, C]
            ce_loss = focal_loss(logits.permute(0, 2, 1), raw_map)
            # 3. Total loss (reconstruction + VQ regularisation)
            loss = ce_loss + vq_loss
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(all_params, max_norm=1.0)
            optimizer.step()
            loss_total += loss.detach().item()
            ce_total += ce_loss.detach().item()
            commit_total += commit_loss.detach().item()
            entropy_total += entropy_loss.detach().item()
        scheduler.step()
        n = len(dataloader_train)
        tqdm.write(
            f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
            f"Epoch {epoch + 1:4d} | "
            f"Loss {loss_total/n:.5f} "
            f"Focal {ce_total/n:.5f} "
            f"Commit {commit_total/n:.5f} "
            f"Entropy {entropy_total/n:.5f} | "
            f"LR {scheduler.get_last_lr()[0]:.6f}"
        )
        # ---- Checkpoint + validation ----
        if (epoch + 1) % args.checkpoint == 0:
            ckpt_path = f"result/pretrain/pretrain-{epoch + 1}.pth"
            torch.save({
                "epoch": epoch + 1,
                "vq_state": model_vq.state_dict(),
                "dh_state": decode_head.state_dict(),
                "optim_state": optimizer.state_dict(),
            }, ckpt_path)
            tqdm.write(f"  Checkpoint saved: {ckpt_path}")
            metrics = validate(model_vq, decode_head, dataloader_val)
            acc_str = f"  [Validate] Acc {metrics['acc']:.4f} Wall Recall {metrics['wall_recall']:.4f}"
            # Report per-class accuracy for classes not yet perfectly reconstructed
            pc = metrics["per_class"]
            detail = " ".join(
                f"c{c}={v:.3f}" for c, v in sorted(pc.items()) if v < 1.0
            )
            if detail:
                acc_str += f"\n  Per-class: {detail}"
            tqdm.write(acc_str)
            model_vq.train()
            decode_head.train()
    # ---- Save the final VQ encoder weights ----
    final_path = "result/pretrain/pretrain_final.pth"
    torch.save({
        "epoch": start_epoch + args.epochs,
        "vq_state": model_vq.state_dict(),
        # The decode head is not saved: joint training does not need it
    }, final_path)
    print(f"\nPre-training finished. Encoder weights saved to: {final_path}")
    print("Command to launch joint-training stage 1 (frozen-encoder warm-up):")
    print(f"  python -m ginka.train_vq --resume True --state {final_path} --freeze_vq True")
# ---------------------------------------------------------------------------
if __name__ == "__main__":
torch.set_num_threads(4)
train()