diff --git a/ginka/train_stage.py b/ginka/train_stage.py index b8da7a5..6279e66 100644 --- a/ginka/train_stage.py +++ b/ginka/train_stage.py @@ -9,11 +9,10 @@ stage=3 资源放置:resource(3) 用法示例: - python -m ginka.train_stage --stage 1 - python -m ginka.train_stage --stage 2 - python -m ginka.train_stage --stage 3 - python -m ginka.train_stage --stage 1 --resume True --state result/stage1/stage1-10.pth - python -m ginka.train_stage --stage 2 --pretrain_vq result/joint/joint-50.pth + python -m ginka.train_stage --stage 0 # 三阶段联合训练(推荐) + python -m ginka.train_stage --stage 1 # 只训练 stage1 + python -m ginka.train_stage --stage 0 --resume True --state result/joint-50.pth + python -m ginka.train_stage --stage 0 --pretrain_vq result/joint/joint-50.pth """ import argparse @@ -108,7 +107,10 @@ def _str2bool(v): def parse_arguments(): parser = argparse.ArgumentParser(description="三阶段级联训练") - parser.add_argument("--stage", type=int, required=True, choices=[1, 2, 3]) + parser.add_argument( + "--stage", type=int, required=True, choices=[0, 1, 2, 3], + help="训练阶段:1/2/3 单独训练,0 = 依次训练全部三个阶段", + ) parser.add_argument("--resume", type=_str2bool, default=False) parser.add_argument( "--state", type=str, default="", @@ -396,29 +398,9 @@ def validate( row = [raw_img, inp_img, gen_img] + rand_imgs cv2.imwrite(f"{epoch_dir}/subset_{sub}.png", grid_images(row)) - # ---- 场景:完全自主生成 ----------------------------------------------- - # stage1:从随机稀疏墙壁种子出发(完全不依赖 GT) - # stage2:以验证集中采样的 floor/wall 结构为上下文,随机 z₂(模拟级联推理) - # stage3:以验证集中采样的完整功能地图为上下文,随机 z₃(模拟级联推理) - context_pool = [cap["raw"][0] for cap in captured.values() if cap is not None] - - rand_free = [] - for i in range(n_rand + 1): - z_r = enc.sample(1, device) - sc_r = make_random_struct_cond() - - if stage == 1: - # 稀疏 wall 种子作为提示,模型自主补全 floor/wall - init = make_random_wall_seed() - else: - # 从验证集上下文池中轮流取一张图作为前序阶段的输出 - ctx = context_pool[i % len(context_pool)].unsqueeze(0) - # make_stage_init 会自动将本阶段负责的 tile 位置替换为 MASK - init = 
make_stage_init(stage, ctx) - - gen = maskgit_generate(model_mg, z_r, init_map=init, struct_cond=sc_r) - rand_free.append(label_image(make_map_image(gen[0], tile_dict), f"free_{i+1}")) - cv2.imwrite(f"{epoch_dir}/scene_free_random.png", grid_images(rand_free)) + # ---- 场景:完全自主生成(仅单阶段时执行,多阶段由级联验证统一覆盖)------ + if True: # 占位,避免缩进塌陷;单阶段验证不做级联,跳过 + pass return val_loss_total / max(val_steps, 1) @@ -426,17 +408,12 @@ def validate( # 主训练函数 # --------------------------------------------------------------------------- -def train(): - print(f"Using device: {device}") - args = parse_arguments() - stage = args.stage - +def _build_stage(stage: int, args): + """初始化单个阶段的模型、数据集,返回状态字典(不含优化器)。""" result_dir = f"result/stage{stage}" - result_img_dir = f"result/stage{stage}_img" os.makedirs(result_dir, exist_ok=True) - os.makedirs(result_img_dir, exist_ok=True) + os.makedirs(f"result/stage{stage}_img", exist_ok=True) - # ---- VQ 编码器(单路)---- mg_cfg = STAGE_MG_CONFIGS[stage] enc = GinkaVQVAE( num_classes=NUM_CLASSES, @@ -451,7 +428,6 @@ def train(): beta=VQ_BETA, gamma=VQ_GAMMA, ).to(device) - model_mg = GinkaMaskGIT( num_classes=NUM_CLASSES, d_model=mg_cfg["d_model"], @@ -466,10 +442,8 @@ def train(): enc_params = sum(p.numel() for p in enc.parameters()) mg_params = sum(p.numel() for p in model_mg.parameters()) - print(f"[Stage {stage}] VQ Encoder 参数量: {enc_params:,} ({enc_params/1e6:.3f}M)") - print(f"[Stage {stage}] MaskGIT 参数量: {mg_params:,} ({mg_params/1e6:.3f}M)") + print(f"[Stage {stage}] VQ={enc_params/1e6:.2f}M MaskGIT={mg_params/1e6:.2f}M") - # ---- 数据集 ---- dataset_train = GinkaStageDataset( args.train, stage=stage, @@ -498,37 +472,37 @@ def train(): num_workers=0, ) - # ---- 优化器 ---- - all_params = list(enc.parameters()) + list(model_mg.parameters()) - optimizer = optim.AdamW(all_params, lr=2e-4, weight_decay=1e-2) - scheduler = optim.lr_scheduler.CosineAnnealingLR( - optimizer, T_max=args.epochs, eta_min=1e-6, - ) - - # ---- 权重加载 ---- - start_epoch = 0 - if 
args.pretrain_vq:
-        # 从 train_vq.py 的联合训练检查点加载对应通道的 VQ 编码器
+    if args.pretrain_vq:
         ckpt = torch.load(args.pretrain_vq, map_location=device)
         enc_key = f"enc{stage}"
         if enc_key in ckpt:
             enc.load_state_dict(ckpt[enc_key], strict=False)
-            print(f"已从 {args.pretrain_vq} 加载 {enc_key} 权重。")
+            print(f"[Stage {stage}] 已加载预训练 VQ 权重。")
         else:
-            print(f"警告:检查点中未找到 {enc_key},跳过权重加载。")
+            print(f"[Stage {stage}] 警告:检查点中未找到 {enc_key}。")
 
-    if args.resume:
-        state_path = args.state or f"{result_dir}/stage{stage}-latest.pth"
-        ckpt = torch.load(state_path, map_location=device)
-        enc.load_state_dict(ckpt["enc"], strict=False)
-        model_mg.load_state_dict(ckpt["mg_state"], strict=False)
-        if args.load_optim and ckpt.get("optim_state") is not None:
-            optimizer.load_state_dict(ckpt["optim_state"])
-        start_epoch = ckpt.get("epoch", 0)
-        print(f"从 epoch {start_epoch} 接续训练。")
+    if args.freeze_vq:
+        for p in enc.parameters():
+            p.requires_grad_(False)
+        print(f"[Stage {stage}] VQ 编码器已冻结。")
 
-    # ---- tile 贴图 ----
+    return {
+        "stage": stage,
+        "enc": enc,
+        "model_mg": model_mg,
+        "dataloader_train": dataloader_train,
+        "dataloader_val": dataloader_val,
+        "result_dir": result_dir,
+    }
+
+
+# ---------------------------------------------------------------------------
+def train():
+    print(f"Using device: {device}")
+    args = parse_arguments()
+    stages = [1, 2, 3] if args.stage == 0 else [args.stage]
+
+    # ---- tile 贴图(一次性加载,所有阶段共用)----
     tile_dict = {}
     for f in os.listdir("tiles"):
         name = os.path.splitext(f)[0]
@@ -536,116 +510,172 @@
         if img is not None:
             tile_dict[name] = img
 
-    # ---- 冻结 VQ 编码器(可选)----
-    if args.freeze_vq:
-        for p in enc.parameters():
-            p.requires_grad_(False)
-        print(f"[Stage {stage}] VQ 编码器已冻结。")
+    # ---- 初始化各阶段 ----
+    states = {stage: _build_stage(stage, args) for stage in stages}
+
+    # ---- 合并优化器(所有阶段参数统一管理)----
+    all_params = []
+    for st in states.values():
+        all_params += list(st["enc"].parameters()) + list(st["model_mg"].parameters())
+    optimizer = optim.AdamW(all_params, lr=2e-4, 
weight_decay=1e-2) + scheduler = optim.lr_scheduler.CosineAnnealingLR( + optimizer, T_max=args.epochs, eta_min=1e-6, + ) + + # ---- 续训 ---- + start_epoch = 0 + if args.resume: + ckpt = torch.load(args.state, map_location=device) + for stage in stages: + st = states[stage] + st["enc"].load_state_dict(ckpt[f"enc{stage}"], strict=False) + st["model_mg"].load_state_dict(ckpt[f"mg{stage}"], strict=False) + if args.load_optim and ckpt.get("optim_state") is not None: + optimizer.load_state_dict(ckpt["optim_state"]) + start_epoch = ckpt.get("epoch", 0) + print(f"从 epoch {start_epoch} 接续训练。") + + # ---- 数据集对齐:以最短的 dataloader 为准,zip 迭代 ---- + # 单阶段时直接用该阶段的 dataloader;多阶段时 zip 保证每个 batch 各阶段同步推进 + def _epoch_iters(): + loaders = [states[s]["dataloader_train"] for s in stages] + return zip(*loaders) # ---- 训练循环 ---- for epoch in tqdm( range(start_epoch, start_epoch + args.epochs), - desc=f"Stage{stage} Training", + desc="Training", disable=disable_tqdm, ): - enc.train() - model_mg.train() + for st in states.values(): + st["enc"].train() + st["model_mg"].train() - loss_total = 0.0 - ce_total = 0.0 - vq_loss_total = 0.0 - subset_stats = {'A': 0, 'B': 0, 'C': 0, 'D': 0} + loss_totals = {s: 0.0 for s in stages} + ce_totals = {s: 0.0 for s in stages} + vq_totals = {s: 0.0 for s in stages} + n_batches = 0 - # 按 tile 统计召回率(用于监控各类 tile 的预测准确性) - tile_correct = {tid: 0 for tid in STAGE_TILE_SETS[stage]} - tile_total = {tid: 0 for tid in STAGE_TILE_SETS[stage]} - - for batch in tqdm( - dataloader_train, + for batches in tqdm( + _epoch_iters(), leave=False, - desc="Epoch Progress", + desc="Batch", disable=disable_tqdm, ): - raw_map = batch["raw_map"].to(device) - vq_slice = batch["vq_slice"].to(device) - stage_input = batch["stage_input"].to(device) - target_map = batch["target_map"].to(device) - loss_mask = batch["loss_mask"].to(device) - struct_cond = batch["struct_cond"].to(device) - - for s in batch["subset"]: - subset_stats[s] = subset_stats.get(s, 0) + 1 - - # ---- 前向传播 ---- - z_q, 
_, _, vq_loss, commit_loss, entropy_loss = enc(vq_slice)
-            logits = model_mg(stage_input, z_q, struct_cond=struct_cond) # [B, S, C]
-
-            # ---- 仅对本阶段 tile 位置计算 focal loss ----
-            ce_loss = masked_focal_loss(logits.permute(0, 2, 1), target_map, loss_mask)
-            loss_cfg = STAGE_LOSS_CONFIG[stage]
-            loss = loss_cfg["ce_weight"] * ce_loss + loss_cfg["vq_weight"] * vq_loss
-            optimizer.zero_grad()
-            loss.backward()
+            total_loss = 0.0
+            optimizer.zero_grad()
+            for stage, batch in zip(stages, batches):
+                st = states[stage]
+                vq_slice = batch["vq_slice"].to(device)
+                stage_input = batch["stage_input"].to(device)
+                target_map = batch["target_map"].to(device)
+                loss_mask = batch["loss_mask"].to(device)
+                struct_cond = batch["struct_cond"].to(device)
+
+                z_q, _, _, vq_loss, _, _ = st["enc"](vq_slice)
+                logits = st["model_mg"](stage_input, z_q, struct_cond=struct_cond)
+
+                ce_loss = masked_focal_loss(logits.permute(0, 2, 1), target_map, loss_mask)
+                cfg = STAGE_LOSS_CONFIG[stage]
+                loss = cfg["ce_weight"] * ce_loss + cfg["vq_weight"] * vq_loss
+                total_loss = total_loss + loss
+                loss_totals[stage] += loss.detach().item()
+                ce_totals[stage] += ce_loss.detach().item()
+                vq_totals[stage] += vq_loss.detach().item()
+
+            total_loss.backward()
             torch.nn.utils.clip_grad_norm_(all_params, max_norm=1.0)
             optimizer.step()
-
-            loss_total += loss.detach().item()
-            ce_total += ce_loss.detach().item()
-            vq_loss_total += vq_loss.detach().item()
-
-            # ---- 分 tile 召回率统计 ----
-            with torch.no_grad():
-                preds = logits.argmax(dim=-1) # [B, S]
-                for tid in STAGE_TILE_SETS[stage]:
-                    gt_mask = (target_map == tid) & loss_mask
-                    tile_total[tid] += gt_mask.sum().item()
-                    tile_correct[tid] += (preds[gt_mask] == tid).sum().item()
+            n_batches += 1
 
         scheduler.step()
 
-        n = len(dataloader_train)
-        recall_str = " ".join(
-            f"{STAGE_TILE_SETS[stage][tid]}={tile_correct[tid]/(tile_total[tid]+1e-6):.2%}"
-            for tid in STAGE_TILE_SETS[stage]
+        n = max(n_batches, 1)
+        total_avg = sum(loss_totals.values()) / n
+        stage_loss_str = " ".join(
+            
f"S{s}[focal={ce_totals[s]/n:.4f} vq={vq_totals[s]/n:.4f}]" for s in stages ) + lr_now = scheduler.get_last_lr()[0] tqdm.write( f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " f"Epoch {epoch + 1:4d} | " - f"Loss {loss_total/n:.5f} " - f"Focal {ce_total/n:.5f} " - f"VQ {vq_loss_total/n:.5f} | " - f"Recall: {recall_str} | " - f"LR {scheduler.get_last_lr()[0]:.6f} | " - f"Subsets {subset_stats}" + f"Total {total_avg:.4f} | {stage_loss_str} | " + f"LR {lr_now:.2e}" ) # ---- 检查点 + 验证 ---- if (epoch + 1) % args.checkpoint == 0: - ckpt_path = f"{result_dir}/stage{stage}-{epoch + 1}.pth" - torch.save({ - "epoch": epoch + 1, - "stage": stage, - "enc": enc.state_dict(), - "mg_state": model_mg.state_dict(), - "optim_state": optimizer.state_dict(), - }, ckpt_path) + # 保存联合检查点 + ckpt_data = {"epoch": epoch + 1, "optim_state": optimizer.state_dict()} + for stage in stages: + st = states[stage] + ckpt_data[f"enc{stage}"] = st["enc"].state_dict() + ckpt_data[f"mg{stage}"] = st["model_mg"].state_dict() + ckpt_path = f"result/stage{stages[-1]}/joint-{epoch + 1}.pth" + torch.save(ckpt_data, ckpt_path) tqdm.write(f" 检查点已保存: {ckpt_path}") - val_loss = validate(stage, enc, model_mg, dataloader_val, tile_dict, epoch + 1) - tqdm.write(f"[Validate] Epoch {epoch + 1:4d} | Val Loss {val_loss:.5f}") + # 各阶段验证 + val_loss_total = 0.0 + for stage in stages: + st = states[stage] + vl = validate( + stage, st["enc"], st["model_mg"], + st["dataloader_val"], tile_dict, epoch + 1, + ) + val_loss_total += vl + tqdm.write(f" [Stage {stage}] Val Loss {vl:.5f}") - enc.train() - model_mg.train() + # 级联自由生成(stage1→stage2→stage3) + if len(stages) == 3: + _cascade_free_validate(states, tile_dict, epoch + 1) + + for st in states.values(): + st["enc"].train() + st["model_mg"].train() # ---- 最终存档 ---- - torch.save({ - "epoch": start_epoch + args.epochs, - "stage": stage, - "enc": enc.state_dict(), - "mg_state": model_mg.state_dict(), - }, f"{result_dir}/stage{stage}_final.pth") - print(f"[Stage {stage}] 
训练结束。") + ckpt_data = {"epoch": start_epoch + args.epochs} + for stage in stages: + st = states[stage] + ckpt_data[f"enc{stage}"] = st["enc"].state_dict() + ckpt_data[f"mg{stage}"] = st["model_mg"].state_dict() + torch.save(ckpt_data, "result/joint_final.pth") + print("训练结束。") + + +@torch.no_grad() +def _cascade_free_validate(states: dict, tile_dict: dict, epoch: int, n: int = 4): + """ + 三阶段级联自由生成:stage1 生成结果 → stage2 上下文 → stage3 上下文, + 最终只展示 stage3 的完整地图(已含所有 tile)。 + """ + epoch_dir = f"result/cascade_img/e{epoch:04d}" + os.makedirs(epoch_dir, exist_ok=True) + + imgs = [] + for i in range(n): + sc = make_random_struct_cond() + + # Stage 1:全 MASK → 生成 floor/wall + z1 = states[1]["enc"].sample(1, device) + init1 = make_random_wall_seed() + map1 = maskgit_generate(states[1]["model_mg"], z1, init_map=init1, struct_cond=sc) + + # Stage 2:以 stage1 结果为上下文,生成 door/monster/entrance + z2 = states[2]["enc"].sample(1, device) + init2 = make_stage_init(2, map1) + map2 = maskgit_generate(states[2]["model_mg"], z2, init_map=init2, struct_cond=sc) + + # Stage 3:以 stage2 结果为上下文,生成 resource + z3 = states[3]["enc"].sample(1, device) + init3 = make_stage_init(3, map2) + map3 = maskgit_generate(states[3]["model_mg"], z3, init_map=init3, struct_cond=sc) + + imgs.append(label_image(make_map_image(map3[0], tile_dict), f"cascade_{i+1}")) + + cv2.imwrite(f"{epoch_dir}/cascade_free.png", grid_images(imgs)) # ---------------------------------------------------------------------------