fix: 三阶段一起训练

Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
unanmed 2026-05-07 22:12:33 +08:00
parent 3676958781
commit 08b90881a8

View File

@ -9,11 +9,10 @@
stage=3 资源放置resource(3) stage=3 资源放置resource(3)
用法示例 用法示例
python -m ginka.train_stage --stage 1 python -m ginka.train_stage --stage 0 # 三阶段联合训练(推荐)
python -m ginka.train_stage --stage 2 python -m ginka.train_stage --stage 1 # 只训练 stage1
python -m ginka.train_stage --stage 3 python -m ginka.train_stage --stage 0 --resume True --state result/joint-50.pth
python -m ginka.train_stage --stage 1 --resume True --state result/stage1/stage1-10.pth python -m ginka.train_stage --stage 0 --pretrain_vq result/joint/joint-50.pth
python -m ginka.train_stage --stage 2 --pretrain_vq result/joint/joint-50.pth
""" """
import argparse import argparse
@ -108,7 +107,10 @@ def _str2bool(v):
def parse_arguments(): def parse_arguments():
parser = argparse.ArgumentParser(description="三阶段级联训练") parser = argparse.ArgumentParser(description="三阶段级联训练")
parser.add_argument("--stage", type=int, required=True, choices=[1, 2, 3]) parser.add_argument(
"--stage", type=int, required=True, choices=[0, 1, 2, 3],
help="训练阶段1/2/3 单独训练0 = 依次训练全部三个阶段",
)
parser.add_argument("--resume", type=_str2bool, default=False) parser.add_argument("--resume", type=_str2bool, default=False)
parser.add_argument( parser.add_argument(
"--state", type=str, default="", "--state", type=str, default="",
@ -396,29 +398,9 @@ def validate(
row = [raw_img, inp_img, gen_img] + rand_imgs row = [raw_img, inp_img, gen_img] + rand_imgs
cv2.imwrite(f"{epoch_dir}/subset_{sub}.png", grid_images(row)) cv2.imwrite(f"{epoch_dir}/subset_{sub}.png", grid_images(row))
# ---- 场景:完全自主生成 ----------------------------------------------- # ---- 场景:完全自主生成(仅单阶段时执行,多阶段由级联验证统一覆盖)------
# stage1从随机稀疏墙壁种子出发完全不依赖 GT if True: # 占位,避免缩进塌陷;单阶段验证不做级联,跳过
# stage2以验证集中采样的 floor/wall 结构为上下文,随机 z₂模拟级联推理 pass
# stage3以验证集中采样的完整功能地图为上下文随机 z₃模拟级联推理
context_pool = [cap["raw"][0] for cap in captured.values() if cap is not None]
rand_free = []
for i in range(n_rand + 1):
z_r = enc.sample(1, device)
sc_r = make_random_struct_cond()
if stage == 1:
# 稀疏 wall 种子作为提示,模型自主补全 floor/wall
init = make_random_wall_seed()
else:
# 从验证集上下文池中轮流取一张图作为前序阶段的输出
ctx = context_pool[i % len(context_pool)].unsqueeze(0)
# make_stage_init 会自动将本阶段负责的 tile 位置替换为 MASK
init = make_stage_init(stage, ctx)
gen = maskgit_generate(model_mg, z_r, init_map=init, struct_cond=sc_r)
rand_free.append(label_image(make_map_image(gen[0], tile_dict), f"free_{i+1}"))
cv2.imwrite(f"{epoch_dir}/scene_free_random.png", grid_images(rand_free))
return val_loss_total / max(val_steps, 1) return val_loss_total / max(val_steps, 1)
@ -426,17 +408,12 @@ def validate(
# 主训练函数 # 主训练函数
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def train(): def _build_stage(stage: int, args):
print(f"Using device: {device}") """初始化单个阶段的模型、数据集,返回状态字典(不含优化器)。"""
args = parse_arguments()
stage = args.stage
result_dir = f"result/stage{stage}" result_dir = f"result/stage{stage}"
result_img_dir = f"result/stage{stage}_img"
os.makedirs(result_dir, exist_ok=True) os.makedirs(result_dir, exist_ok=True)
os.makedirs(result_img_dir, exist_ok=True) os.makedirs(f"result/stage{stage}_img", exist_ok=True)
# ---- VQ 编码器(单路)----
mg_cfg = STAGE_MG_CONFIGS[stage] mg_cfg = STAGE_MG_CONFIGS[stage]
enc = GinkaVQVAE( enc = GinkaVQVAE(
num_classes=NUM_CLASSES, num_classes=NUM_CLASSES,
@ -451,7 +428,6 @@ def train():
beta=VQ_BETA, beta=VQ_BETA,
gamma=VQ_GAMMA, gamma=VQ_GAMMA,
).to(device) ).to(device)
model_mg = GinkaMaskGIT( model_mg = GinkaMaskGIT(
num_classes=NUM_CLASSES, num_classes=NUM_CLASSES,
d_model=mg_cfg["d_model"], d_model=mg_cfg["d_model"],
@ -466,10 +442,8 @@ def train():
enc_params = sum(p.numel() for p in enc.parameters()) enc_params = sum(p.numel() for p in enc.parameters())
mg_params = sum(p.numel() for p in model_mg.parameters()) mg_params = sum(p.numel() for p in model_mg.parameters())
print(f"[Stage {stage}] VQ Encoder 参数量: {enc_params:,} ({enc_params/1e6:.3f}M)") print(f"[Stage {stage}] VQ={enc_params/1e6:.2f}M MaskGIT={mg_params/1e6:.2f}M")
print(f"[Stage {stage}] MaskGIT 参数量: {mg_params:,} ({mg_params/1e6:.3f}M)")
# ---- 数据集 ----
dataset_train = GinkaStageDataset( dataset_train = GinkaStageDataset(
args.train, args.train,
stage=stage, stage=stage,
@ -498,37 +472,37 @@ def train():
num_workers=0, num_workers=0,
) )
# ---- 优化器 ----
all_params = list(enc.parameters()) + list(model_mg.parameters())
optimizer = optim.AdamW(all_params, lr=2e-4, weight_decay=1e-2)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=args.epochs, eta_min=1e-6,
)
# ---- 权重加载 ----
start_epoch = 0
if args.pretrain_vq: if args.pretrain_vq:
# 从 train_vq.py 的联合训练检查点加载对应通道的 VQ 编码器
ckpt = torch.load(args.pretrain_vq, map_location=device) ckpt = torch.load(args.pretrain_vq, map_location=device)
enc_key = f"enc{stage}" enc_key = f"enc{stage}"
if enc_key in ckpt: if enc_key in ckpt:
enc.load_state_dict(ckpt[enc_key], strict=False) enc.load_state_dict(ckpt[enc_key], strict=False)
print(f"已从 {args.pretrain_vq} 加载 {enc_key} 权重。") print(f"[Stage {stage}] 已加载预训练 VQ 权重。")
else: else:
print(f"警告:检查点中未找到 {enc_key},跳过权重加载") print(f"[Stage {stage}] 警告:检查点中未找到 {enc_key}")
if args.resume: if args.freeze_vq:
state_path = args.state or f"{result_dir}/stage{stage}-latest.pth" for p in enc.parameters():
ckpt = torch.load(state_path, map_location=device) p.requires_grad_(False)
enc.load_state_dict(ckpt["enc"], strict=False) print(f"[Stage {stage}] VQ 编码器已冻结。")
model_mg.load_state_dict(ckpt["mg_state"], strict=False)
if args.load_optim and ckpt.get("optim_state") is not None:
optimizer.load_state_dict(ckpt["optim_state"])
start_epoch = ckpt.get("epoch", 0)
print(f"从 epoch {start_epoch} 接续训练。")
# ---- tile 贴图 ---- return {
"stage": stage,
"enc": enc,
"model_mg": model_mg,
"dataloader_train": dataloader_train,
"dataloader_val": dataloader_val,
"result_dir": result_dir,
}
# ---------------------------------------------------------------------------
def train():
print(f"Using device: {device}")
args = parse_arguments()
stages = [1, 2, 3] if args.stage == 0 else [args.stage]
# ---- tile 贴图(一次性加载,所有阶段共用)----
tile_dict = {} tile_dict = {}
for f in os.listdir("tiles"): for f in os.listdir("tiles"):
name = os.path.splitext(f)[0] name = os.path.splitext(f)[0]
@ -536,116 +510,172 @@ def train():
if img is not None: if img is not None:
tile_dict[name] = img tile_dict[name] = img
# ---- 冻结 VQ 编码器(可选)---- # ---- 初始化各阶段 ----
if args.freeze_vq: states = {stage: _build_stage(stage, args) for stage in stages}
for p in enc.parameters():
p.requires_grad_(False) # ---- 合并优化器(所有阶段参数统一管理)----
print(f"[Stage {stage}] VQ 编码器已冻结。") all_params = []
for st in states.values():
all_params += list(st["enc"].parameters()) + list(st["model_mg"].parameters())
optimizer = optim.AdamW(all_params, lr=2e-4, weight_decay=1e-2)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=args.epochs, eta_min=1e-6,
)
# ---- 续训 ----
start_epoch = 0
if args.resume:
ckpt = torch.load(args.state, map_location=device)
for stage in stages:
st = states[stage]
st["enc"].load_state_dict(ckpt[f"enc{stage}"], strict=False)
st["model_mg"].load_state_dict(ckpt[f"mg{stage}"], strict=False)
if args.load_optim and ckpt.get("optim_state") is not None:
optimizer.load_state_dict(ckpt["optim_state"])
start_epoch = ckpt.get("epoch", 0)
print(f"从 epoch {start_epoch} 接续训练。")
    # ---- Dataset alignment: zip the stage dataloaders, truncating to the shortest ----
    # With a single stage this simply iterates that stage's dataloader; with
    # multiple stages, zip keeps every stage advancing in lockstep per batch.
    def _epoch_iters():
        # NOTE(review): zip stops at the shortest loader, so any longer stage
        # dataset is only partially consumed each epoch — confirm this is intended.
        loaders = [states[s]["dataloader_train"] for s in stages]
        return zip(*loaders)
# ---- 训练循环 ---- # ---- 训练循环 ----
for epoch in tqdm( for epoch in tqdm(
range(start_epoch, start_epoch + args.epochs), range(start_epoch, start_epoch + args.epochs),
desc=f"Stage{stage} Training", desc="Training",
disable=disable_tqdm, disable=disable_tqdm,
): ):
enc.train() for st in states.values():
model_mg.train() st["enc"].train()
st["model_mg"].train()
loss_total = 0.0 loss_totals = {s: 0.0 for s in stages}
ce_total = 0.0 ce_totals = {s: 0.0 for s in stages}
vq_loss_total = 0.0 vq_totals = {s: 0.0 for s in stages}
subset_stats = {'A': 0, 'B': 0, 'C': 0, 'D': 0} n_batches = 0
# 按 tile 统计召回率(用于监控各类 tile 的预测准确性) for batches in tqdm(
tile_correct = {tid: 0 for tid in STAGE_TILE_SETS[stage]} _epoch_iters(),
tile_total = {tid: 0 for tid in STAGE_TILE_SETS[stage]}
for batch in tqdm(
dataloader_train,
leave=False, leave=False,
desc="Epoch Progress", desc="Batch",
disable=disable_tqdm, disable=disable_tqdm,
): ):
raw_map = batch["raw_map"].to(device)
vq_slice = batch["vq_slice"].to(device)
stage_input = batch["stage_input"].to(device)
target_map = batch["target_map"].to(device)
loss_mask = batch["loss_mask"].to(device)
struct_cond = batch["struct_cond"].to(device)
for s in batch["subset"]:
subset_stats[s] = subset_stats.get(s, 0) + 1
# ---- 前向传播 ----
z_q, _, _, vq_loss, commit_loss, entropy_loss = enc(vq_slice)
logits = model_mg(stage_input, z_q, struct_cond=struct_cond) # [B, S, C]
# ---- 仅对本阶段 tile 位置计算 focal loss ----
ce_loss = masked_focal_loss(logits.permute(0, 2, 1), target_map, loss_mask)
loss_cfg = STAGE_LOSS_CONFIG[stage]
loss = loss_cfg["ce_weight"] * ce_loss + loss_cfg["vq_weight"] * vq_loss
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() total_loss = 0.0
for stage, batch in zip(stages, batches):
st = states[stage]
vq_slice = batch["vq_slice"].to(device)
stage_input = batch["stage_input"].to(device)
target_map = batch["target_map"].to(device)
loss_mask = batch["loss_mask"].to(device)
struct_cond = batch["struct_cond"].to(device)
z_q, _, _, vq_loss, _, _ = st["enc"](vq_slice)
logits = st["model_mg"](stage_input, z_q, struct_cond=struct_cond)
ce_loss = masked_focal_loss(logits.permute(0, 2, 1), target_map, loss_mask)
cfg = STAGE_LOSS_CONFIG[stage]
loss = cfg["ce_weight"] * ce_loss + cfg["vq_weight"] * vq_loss
total_loss = total_loss + loss
loss_totals[stage] += loss.detach().item()
ce_totals[stage] += ce_loss.detach().item()
vq_totals[stage] += vq_loss.detach().item()
total_loss.backward()
torch.nn.utils.clip_grad_norm_(all_params, max_norm=1.0) torch.nn.utils.clip_grad_norm_(all_params, max_norm=1.0)
optimizer.step() optimizer.step()
n_batches += 1
loss_total += loss.detach().item()
ce_total += ce_loss.detach().item()
vq_loss_total += vq_loss.detach().item()
# ---- 分 tile 召回率统计 ----
with torch.no_grad():
preds = logits.argmax(dim=-1) # [B, S]
for tid in STAGE_TILE_SETS[stage]:
gt_mask = (target_map == tid) & loss_mask
tile_total[tid] += gt_mask.sum().item()
tile_correct[tid] += (preds[gt_mask] == tid).sum().item()
scheduler.step() scheduler.step()
n = len(dataloader_train) n = max(n_batches, 1)
recall_str = " ".join( total_avg = sum(loss_totals.values()) / n
f"{STAGE_TILE_SETS[stage][tid]}={tile_correct[tid]/(tile_total[tid]+1e-6):.2%}" stage_loss_str = " ".join(
for tid in STAGE_TILE_SETS[stage] f"S{s}[focal={ce_totals[s]/n:.4f} vq={vq_totals[s]/n:.4f}]" for s in stages
) )
lr_now = scheduler.get_last_lr()[0]
tqdm.write( tqdm.write(
f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
f"Epoch {epoch + 1:4d} | " f"Epoch {epoch + 1:4d} | "
f"Loss {loss_total/n:.5f} " f"Total {total_avg:.4f} | {stage_loss_str} | "
f"Focal {ce_total/n:.5f} " f"LR {lr_now:.2e}"
f"VQ {vq_loss_total/n:.5f} | "
f"Recall: {recall_str} | "
f"LR {scheduler.get_last_lr()[0]:.6f} | "
f"Subsets {subset_stats}"
) )
# ---- 检查点 + 验证 ---- # ---- 检查点 + 验证 ----
if (epoch + 1) % args.checkpoint == 0: if (epoch + 1) % args.checkpoint == 0:
ckpt_path = f"{result_dir}/stage{stage}-{epoch + 1}.pth" # 保存联合检查点
torch.save({ ckpt_data = {"epoch": epoch + 1, "optim_state": optimizer.state_dict()}
"epoch": epoch + 1, for stage in stages:
"stage": stage, st = states[stage]
"enc": enc.state_dict(), ckpt_data[f"enc{stage}"] = st["enc"].state_dict()
"mg_state": model_mg.state_dict(), ckpt_data[f"mg{stage}"] = st["model_mg"].state_dict()
"optim_state": optimizer.state_dict(), ckpt_path = f"result/stage{stages[-1]}/joint-{epoch + 1}.pth"
}, ckpt_path) torch.save(ckpt_data, ckpt_path)
tqdm.write(f" 检查点已保存: {ckpt_path}") tqdm.write(f" 检查点已保存: {ckpt_path}")
val_loss = validate(stage, enc, model_mg, dataloader_val, tile_dict, epoch + 1) # 各阶段验证
tqdm.write(f"[Validate] Epoch {epoch + 1:4d} | Val Loss {val_loss:.5f}") val_loss_total = 0.0
for stage in stages:
st = states[stage]
vl = validate(
stage, st["enc"], st["model_mg"],
st["dataloader_val"], tile_dict, epoch + 1,
)
val_loss_total += vl
tqdm.write(f" [Stage {stage}] Val Loss {vl:.5f}")
enc.train() # 级联自由生成stage1→stage2→stage3
model_mg.train() if len(stages) == 3:
_cascade_free_validate(states, tile_dict, epoch + 1)
for st in states.values():
st["enc"].train()
st["model_mg"].train()
# ---- 最终存档 ---- # ---- 最终存档 ----
torch.save({ ckpt_data = {"epoch": start_epoch + args.epochs}
"epoch": start_epoch + args.epochs, for stage in stages:
"stage": stage, st = states[stage]
"enc": enc.state_dict(), ckpt_data[f"enc{stage}"] = st["enc"].state_dict()
"mg_state": model_mg.state_dict(), ckpt_data[f"mg{stage}"] = st["model_mg"].state_dict()
}, f"{result_dir}/stage{stage}_final.pth") torch.save(ckpt_data, "result/joint_final.pth")
print(f"[Stage {stage}] 训练结束。") print("训练结束。")
@torch.no_grad()
def _cascade_free_validate(states: dict, tile_dict: dict, epoch: int, n: int = 4):
    """Cascaded free-generation check: stage1 output feeds stage2 as context,
    stage2 output feeds stage3; only the final stage3 map (which by then
    contains every tile class) is rendered to disk.

    Args:
        states: per-stage dicts keyed 1..3, each holding at least an "enc"
            (VQ encoder with ``.sample``) and a "model_mg" (MaskGIT model).
        tile_dict: tile name -> image, used to rasterize generated maps.
        epoch: current epoch number, used only for the output directory name.
        n: number of cascaded samples to generate (keyword with default, so
            existing callers are unaffected).
    """
    epoch_dir = f"result/cascade_img/e{epoch:04d}"
    os.makedirs(epoch_dir, exist_ok=True)
    imgs = []
    for i in range(n):
        # One random structural condition shared by all three stages of a sample,
        # so the cascade is conditioned consistently end to end.
        sc = make_random_struct_cond()
        # Stage 1: start from a sparse random wall seed -> generate floor/wall.
        z1 = states[1]["enc"].sample(1, device)
        init1 = make_random_wall_seed()
        map1 = maskgit_generate(states[1]["model_mg"], z1, init_map=init1, struct_cond=sc)
        # Stage 2: stage1 result as context -> generate door/monster/entrance.
        # make_stage_init presumably masks out the tiles this stage must fill — TODO confirm.
        z2 = states[2]["enc"].sample(1, device)
        init2 = make_stage_init(2, map1)
        map2 = maskgit_generate(states[2]["model_mg"], z2, init_map=init2, struct_cond=sc)
        # Stage 3: stage2 result as context -> generate resources.
        z3 = states[3]["enc"].sample(1, device)
        init3 = make_stage_init(3, map2)
        map3 = maskgit_generate(states[3]["model_mg"], z3, init_map=init3, struct_cond=sc)
        imgs.append(label_image(make_map_image(map3[0], tile_dict), f"cascade_{i+1}"))
    cv2.imwrite(f"{epoch_dir}/cascade_free.png", grid_images(imgs))
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------