# ginka-generator/ginka/train_seperated.py
import argparse
import math
import os
import sys
import random
from datetime import datetime
import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
from torch.utils.data import DataLoader
from .vqvae.quantize import VectorQuantizer
from .vqvae.model import GinkaVQVAE
from .maskGIT.model import GinkaMaskGIT
from .dataset import GinkaSeperatedDataset
from shared.image import matrix_to_image_cv
# Three-stage cascaded map-generation training script.
#
# Overall architecture:
# - Three independent VQ-VAE encoders (vq1/vq2/vq3) encode each stage's map
#   context into discrete latents, unified by a shared VectorQuantizer into z_q.
# - Three independent MaskGITs (mg1/mg2/mg3), each conditioned on z_q and
#   struct_inject, iteratively decode the map tile sequence stage by stage.
#
# Generation target per stage:
#   stage1 -> floor / wall (map skeleton)
#   stage2 -> door / monster / entrance (functional entities)
#   stage3 -> resource (resource points)
# Tile IDs:
#   0 floor, 1 wall, 2 door, 3 resource, 4 monster, 5 entrance, 6 mask (MASK_TOKEN)

# Shared VQ-VAE hyperparameters.
# The three encoders (vq1/vq2/vq3) share these values and encode the three
# stage contexts independently.
VQ_L = 2          # codeword sequence length (each encoder emits L codewords; merged to L*3)
VQ_K = 8          # codebook size (number of discrete entries)
VQ_D_Z = 64       # codeword dimension
VQ_BETA = 0.5     # commit-loss weight (keeps encoder outputs from drifting off the codebook)
VQ_GAMMA = 0.0    # entropy-loss weight (currently unused)
VQ_LAYERS = 3     # VQ-VAE Transformer layer count
VQ_DIM_FF = 512   # VQ-VAE feed-forward hidden size
VQ_D_MODEL = 64   # VQ-VAE Transformer model dimension
VQ_NHEAD = 8      # VQ-VAE attention head count

# Stage-1 MaskGIT hyperparameters.
STAGE1_MG_DMODEL = 192
STAGE1_MG_NHEAD = 8
STAGE1_MG_NUM_LAYERS = 6
STAGE1_MG_DIM_FF = 1024
# Stage-2 MaskGIT hyperparameters.
STAGE2_MG_DMODEL = 192
STAGE2_MG_NHEAD = 8
STAGE2_MG_NUM_LAYERS = 6
STAGE2_MG_DIM_FF = 1024
# Stage-3 MaskGIT hyperparameters.
STAGE3_MG_DMODEL = 192
STAGE3_MG_NHEAD = 8
STAGE3_MG_NUM_LAYERS = 6
STAGE3_MG_DIM_FF = 1024

# Per-stage focal-loss weights (tune each stage's contribution to the total loss).
STAGE1_FOCAL_WEIGHT = 1.0
STAGE2_FOCAL_WEIGHT = 1.0
STAGE3_FOCAL_WEIGHT = 1.0
# Per-stage VQ commit-loss weights (currently unused; VQ_BETA applies globally).
STAGE1_VQ_WEIGHT = 0.5
STAGE2_VQ_WEIGHT = 0.5
STAGE3_VQ_WEIGHT = 0.5

# Global parameters.
NUM_CLASSES = 7            # number of tile types
MASK_TOKEN = 6             # mask tile ID
MAP_W = 13                 # map width
MAP_H = 13                 # map height
MAP_SIZE = MAP_W * MAP_H   # tiles per map
GENERATE_STEP = 18         # MaskGIT sampling steps
SUBSET2_WALL_PROB = 0.7    # probability of masking walls in subset 2
SUBSET_WEIGHTS = (0.5, 0.3, 0.2)  # sampling probability of each subset
MG_Z_DROPOUT = 0.1         # dropout probability on the z latent
MG_STRUCT_DROPOUT = 0.1    # dropout probability on the structure conditions

# Loss parameters.
FOCAL_GAMMA = 2.0  # focal-loss focusing parameter
# Fix: VQ_BETA was previously assigned twice with the same value (0.5); the
# redundant second assignment has been removed — the definition above is
# authoritative.

# Training hyperparameters.
BATCH_SIZE = 64      # samples per batch
LR = 1e-4            # AdamW initial learning rate
MIN_LR = 1e-6        # cosine-annealing floor
WEIGHT_DECAY = 1e-4  # L2 regularization coefficient
EPOCHS = 400         # total training epochs
CHECKPOINT = 20      # validate + save a checkpoint every this many epochs

# NOTE(review): the CUDA device index is hard-coded to 1 — confirm this matches
# the intended GPU on the training host.
device = torch.device(
    "cuda:1" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
# Disable tqdm progress bars when stdout is not a TTY (e.g. redirected logs).
disable_tqdm = not sys.stdout.isatty()
def _str2bool(v: str):
if isinstance(v, bool): return v
if v.lower() in ('true', '1', 'yes'): return True
if v.lower() in ('false', '0', 'no'): return False
raise argparse.ArgumentTypeError(f"布尔值应为 True/False收到: {v!r}")
def parse_arguments():
    """Build and parse the CLI arguments for the three-stage training run."""
    parser = argparse.ArgumentParser(description="三阶段级联训练")
    # (flag, type, default, help) — help=None matches argparse's own default.
    options = [
        ("--resume", _str2bool, False, None),
        ("--state", str, "", "续训时检查点路径"),
        ("--train", str, "ginka-dataset.json", None),
        ("--validate", str, "ginka-eval.json", None),
        ("--load_optim", _str2bool, True, None),
    ]
    for flag, opt_type, default, help_text in options:
        parser.add_argument(flag, type=opt_type, default=default, help=help_text)
    return parser.parse_args()
def build_model(device: torch.device):
    """Construct all models, the shared quantizer, the optimizer and the scheduler.

    Returns:
        Tuple (vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler) —
        the order every other function in this file unpacks.
    """
    # Three VQ-VAE encoders, one per stage context (encoder_stage1/2/3).
    # Each outputs [B, L, d_z]; the three outputs are concatenated before the
    # shared quantizer.
    vq_kwargs = dict(
        num_classes=NUM_CLASSES, L=VQ_L, K=VQ_K, d_model=VQ_D_MODEL,
        nhead=VQ_NHEAD, num_layers=VQ_LAYERS, dim_ff=VQ_DIM_FF, map_size=MAP_SIZE
    )
    vq1 = GinkaVQVAE(**vq_kwargs).to(device)  # encodes the stage1 context (floor/wall)
    vq2 = GinkaVQVAE(**vq_kwargs).to(device)  # encodes the stage2 context (door/monster/entrance)
    vq3 = GinkaVQVAE(**vq_kwargs).to(device)  # encodes the stage3 context (resource)
    # Three independent MaskGIT decoders, each conditioned on the full
    # three-stage z_q plus struct_inject.
    mg1 = GinkaMaskGIT(
        num_classes=NUM_CLASSES, d_model=STAGE1_MG_DMODEL, d_z=VQ_D_Z, dim_ff=STAGE1_MG_DIM_FF,
        nhead=STAGE1_MG_NHEAD, num_layers=STAGE1_MG_NUM_LAYERS, map_size=MAP_SIZE
    ).to(device)
    mg2 = GinkaMaskGIT(
        num_classes=NUM_CLASSES, d_model=STAGE2_MG_DMODEL, d_z=VQ_D_Z, dim_ff=STAGE2_MG_DIM_FF,
        nhead=STAGE2_MG_NHEAD, num_layers=STAGE2_MG_NUM_LAYERS, map_size=MAP_SIZE
    ).to(device)
    mg3 = GinkaMaskGIT(
        num_classes=NUM_CLASSES, d_model=STAGE3_MG_DMODEL, d_z=VQ_D_Z, dim_ff=STAGE3_MG_DIM_FF,
        nhead=STAGE3_MG_NHEAD, num_layers=STAGE3_MG_NUM_LAYERS, map_size=MAP_SIZE
    ).to(device)
    # All six models share one optimizer for end-to-end joint training.
    all_params = (
        list(vq1.parameters()) + list(vq2.parameters()) + list(vq3.parameters()) +
        list(mg1.parameters()) + list(mg2.parameters()) + list(mg3.parameters())
    )
    # Fix: use the declared WEIGHT_DECAY constant instead of the previously
    # hard-coded 1e-4 (same value today, but now a single source of truth).
    optimizer = optim.AdamW(all_params, lr=LR, weight_decay=WEIGHT_DECAY)
    # Cosine annealing: decay from LR to MIN_LR over the full training run.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=MIN_LR)
    # Shared VectorQuantizer: codebook lookup in the forward pass only; it is
    # not registered with the optimizer's parameter list above.
    quantizer = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z).to(device)
    return vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler
def focal_loss(logits, target, gamma: float = 2.0):
    """Mean focal loss over a batch of tile sequences.

    Args:
        logits: [B, L, C] class scores; permuted to [B, C, L] to match the
            layout cross_entropy expects.
        target: [B, L] integer class indices.
        gamma: focusing parameter. Generalized from the module-level
            FOCAL_GAMMA constant; the default (2.0) matches it, so existing
            callers are unaffected.

    Returns:
        Scalar tensor: mean focal loss over all positions.
    """
    ce = F.cross_entropy(logits.permute(0, 2, 1), target, reduction='none')
    pt = torch.exp(-ce)  # pt = model's predicted probability of the true class
    # Down-weight confidently-correct positions so training focuses on hard ones.
    focal = ((1 - pt) ** gamma) * ce
    return focal.mean()
def random_struct(device: torch.device) -> torch.Tensor:
    """Sample a random structure-condition vector for unconditional generation.

    Layout: [cond_sym (0-7), cond_room (0-2), cond_branch (0-2), cond_outer (0-1)]
    — symmetry type, room-count tier, branch-complexity tier, outer-corridor flag.
    Returns a [1, 4] long tensor on the given device.
    """
    # Inclusive upper bound of each condition slot, in struct_inject order.
    upper_bounds = (7, 2, 2, 1)
    values = [random.randint(0, hi) for hi in upper_bounds]
    return torch.LongTensor(values).unsqueeze(0).to(device)
def maskgit_sample(
    model: torch.nn.Module, inp: torch.Tensor, z: torch.Tensor,
    struct: torch.Tensor, steps: int, keep_fixed: bool = True
) -> np.ndarray:
    """Iteratively unmask a tile sequence with a MaskGIT model.

    Operates on the first batch element only (all indexing below uses
    current[0]); callers pass batch size 1.

    Args:
        model: MaskGIT decoder called as model(tokens, z, struct).
        inp: [1, MAP_SIZE] token sequence; MASK_TOKEN marks positions to fill.
        z: latent condition passed through to the model.
        struct: structure-condition vector passed through to the model.
        steps: number of demasking iterations.
        keep_fixed: when True, positions that are already non-mask in `inp`
            (the previous stage's output) are locked and never re-sampled;
            when False, every position may be re-estimated, giving more
            diverse results.

    Returns:
        The final map as a (MAP_H, MAP_W) numpy array.
    """
    current = inp.clone()
    # Iterative demasking: each step resamples all tokens, then re-masks the
    # least-confident positions for the next step.
    for step in range(steps):
        logits = model(current, z, struct)
        probs = F.softmax(logits, dim=-1)
        dist = torch.distributions.Categorical(probs)
        sampled = dist.sample()
        # Confidence = probability the model assigned to its own sampled token.
        confidences = torch.gather(probs, -1, sampled.unsqueeze(-1)).squeeze(-1)
        # Cosine schedule: the number of positions kept masked decays to 0 as
        # steps progress.
        ratio = math.cos(((step + 1) / steps) * math.pi / 2)
        num_to_mask = math.floor(ratio * MAP_SIZE)
        if keep_fixed:
            # Non-mask positions in the input (from the previous stage) keep
            # their original value; confidence 1.0 prevents them being re-masked.
            fixed_mask = (current[0] != MASK_TOKEN)
            sampled[0, fixed_mask] = current[0, fixed_mask]
            confidences[0, fixed_mask] = 1.0
        if num_to_mask > 0:
            # Re-mask the lowest-confidence positions; they are re-predicted
            # in the next iteration.
            _, mask_indices = torch.topk(confidences[0], k=num_to_mask, largest=False)
            sampled[0].scatter_(0, mask_indices, MASK_TOKEN)
        current = sampled
        if (current[0] == MASK_TOKEN).sum() == 0:
            break
    # Fallback: if any mask tokens remain (should not happen with this
    # schedule), fill them deterministically with argmax.
    still_masked = (current[0] == MASK_TOKEN)
    if still_masked.any():
        logits = model(current, z, struct)
        current[0, still_masked] = torch.argmax(logits[0, still_masked], dim=-1)
    return current[0].cpu().numpy().reshape(MAP_H, MAP_W)
def full_generate_random_z(
    input: torch.Tensor,
    struct: torch.Tensor,
    models: list[torch.nn.Module],
    device: torch.device,
    keep_fixed: tuple[bool, bool, bool] = (True, True, True)
) -> tuple:
    """Run the three-stage cascade with a freshly sampled latent z.

    Samples z from the shared quantizer's codebook and delegates to
    full_generate_specific_z — previously the two functions duplicated the
    entire cascade body; this keeps a single implementation.

    Args:
        input: [1, MAP_SIZE] stage-1 seed tokens (MASK_TOKEN marks free cells).
        struct: [1, 4] structure-condition vector.
        models: the 9-tuple returned by build_model.
        device: device for intermediate tensors.
        keep_fixed: per-stage keep_fixed flags forwarded to maskgit_sample.

    Returns:
        (stage1 map, stage1+2 merged map, stage1+2+3 merged map).
    """
    quantizer = models[6]  # build_model order: vq1..3, mg1..3, quantizer, optim, sched
    with torch.no_grad():
        z = quantizer.sample(1, VQ_L, device)
    return full_generate_specific_z(input, z, struct, models, device, keep_fixed=keep_fixed)
def full_generate_specific_z(
    input: torch.Tensor,
    z: torch.Tensor,
    struct: torch.Tensor,
    models: list[torch.nn.Module],
    device: torch.device,
    keep_fixed: tuple[bool, bool, bool] = (True, True, True)
) -> tuple:
    """Run the three-stage cascade with a caller-supplied latent z.

    Same pipeline as full_generate_random_z, but conditioned on the given z
    (e.g. the quantized latent of a real sample during validation).

    Returns:
        (stage1 map, stage1+2 merged map, stage1+2+3 merged map), each a
        (MAP_H, MAP_W) numpy array.
    """
    vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler = models
    with torch.no_grad():
        # stage1: generate the floor/wall skeleton.
        pred1_np = maskgit_sample(mg1, input.clone(), z, struct, GENERATE_STEP, keep_fixed=keep_fixed[0])
        inp2 = torch.tensor(pred1_np.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE)
        inp2[inp2 == 0] = MASK_TOKEN  # floor cells are handed to stage2 to fill
        # stage2: place door/monster/entrance; non-zero outputs override the skeleton.
        pred2_np = maskgit_sample(mg2, inp2, z, struct, GENERATE_STEP, keep_fixed=keep_fixed[1])
        merged12 = pred1_np.copy()
        merged12[pred2_np != 0] = pred2_np[pred2_np != 0]
        inp3 = torch.tensor(merged12.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE)
        inp3[inp3 == 0] = MASK_TOKEN  # remaining floor cells go to stage3
        # stage3: fill resources; non-zero outputs override the merged map.
        pred3_np = maskgit_sample(mg3, inp3, z, struct, GENERATE_STEP, keep_fixed=keep_fixed[2])
        merged123 = merged12.copy()
        merged123[pred3_np != 0] = pred3_np[pred3_np != 0]
    return pred1_np, merged12, merged123
def annotate(img: np.ndarray, text: str) -> np.ndarray:
    """Overlay a label in the top-left corner of a copy of the image.

    Black 2-px outline under a white 1-px fill keeps the text readable on any
    background.
    """
    out = img.copy()
    origin = (2, 14)
    font = cv2.FONT_HERSHEY_SIMPLEX
    scale = 0.4
    cv2.putText(out, text, origin, font, scale, (0, 0, 0), 2)        # outline
    cv2.putText(out, text, origin, font, scale, (255, 255, 255), 1)  # fill
    return out
def rand_keep() -> tuple[bool, bool, bool]:
    """Choose one random keep_fixed flag and share it across all three stages."""
    flag = random.choice([True, False])
    return (flag,) * 3
def keep_label(kf: tuple[bool, bool, bool]) -> str:
    """Map a keep_fixed triple to its display label: 'fix' when locked, else 'free'."""
    return 'fix' if kf[0] else 'free'
# Validation visual, part 1 (3x3 grid): row1 = encoder inputs, row2 = masked
# inputs, row3 = per-stage predictions merged back into the masked inputs.
def visualize_part1(batch, logits1, logits2, logits3, tile_dict):
    SEP = 3          # separator width between grid cells, in pixels
    TILE_SIZE = 32   # rendered pixel size of one map tile
    img_h = MAP_H * TILE_SIZE
    img_w = MAP_W * TILE_SIZE
    def to_img(mat):
        # Render a tile-ID matrix with the preloaded tile sprites.
        return matrix_to_image_cv(mat, tile_dict, TILE_SIZE)
    # Greedy (argmax) per-stage predictions for the first sample of the batch.
    pred1 = torch.argmax(logits1[0], dim=-1).cpu().numpy().reshape(MAP_H, MAP_W)
    pred2 = torch.argmax(logits2[0], dim=-1).cpu().numpy().reshape(MAP_H, MAP_W)
    pred3 = torch.argmax(logits3[0], dim=-1).cpu().numpy().reshape(MAP_H, MAP_W)
    enc1_np = batch["encoder_stage1"][0].numpy().reshape(MAP_H, MAP_W)
    enc2_np = batch["encoder_stage2"][0].numpy().reshape(MAP_H, MAP_W)
    enc3_np = batch["encoder_stage3"][0].numpy().reshape(MAP_H, MAP_W)
    inp1_np = batch["input_stage1"][0].numpy().reshape(MAP_H, MAP_W)
    inp2_np = batch["input_stage2"][0].numpy().reshape(MAP_H, MAP_W)
    inp3_np = batch["input_stage3"][0].numpy().reshape(MAP_H, MAP_W)
    # Fill MASK positions of each stage input with the model prediction while
    # keeping the non-masked positions at their original values.
    result1 = inp1_np.copy()
    result1[inp1_np == MASK_TOKEN] = pred1[inp1_np == MASK_TOKEN]
    result2 = inp2_np.copy()
    result2[inp2_np == MASK_TOKEN] = pred2[inp2_np == MASK_TOKEN]
    result3 = inp3_np.copy()
    result3[inp3_np == MASK_TOKEN] = pred3[inp3_np == MASK_TOKEN]
    rows = [
        [to_img(enc1_np), to_img(enc2_np), to_img(enc3_np)],
        [to_img(inp1_np), to_img(inp2_np), to_img(inp3_np)],
        [to_img(result1), to_img(result2), to_img(result3)],
    ]
    # Assemble the 3x3 grid on a white canvas with SEP-pixel separators.
    grid = np.ones((3 * img_h + 4 * SEP, 3 * img_w + 4 * SEP, 3), dtype=np.uint8) * 255
    for r, row in enumerate(rows):
        for c, img in enumerate(row):
            y = SEP + r * (img_h + SEP)
            x = SEP + c * (img_w + SEP)
            grid[y:y + img_h, x:x + img_w] = img
    return grid
def visualize_part2(batch, z_q, models, device, tile_dict):
    """Validation visual, part 2 (2x4 grid).

    Row 1: the three ground-truth encoder maps (the 4th cell stays blank).
    Row 2: the stage-1 masked input plus the three cascade outputs generated
    autoregressively from this sample's real quantized latent z_q.
    """
    SEP = 3          # separator width between grid cells, in pixels
    TILE_SIZE = 32   # rendered pixel size of one map tile
    img_h = MAP_H * TILE_SIZE
    img_w = MAP_W * TILE_SIZE
    def to_img(mat):
        # Render a tile-ID matrix with the preloaded tile sprites.
        return matrix_to_image_cv(mat, tile_dict, TILE_SIZE)
    inp1_t = batch["input_stage1"][0:1].to(device).reshape(1, MAP_SIZE)
    struct_t = batch["struct_inject"][0:1].to(device)
    kf = rand_keep()
    auto_pred1_np, auto_merged12, auto_merged123 = full_generate_specific_z(
        inp1_t, z_q[0:1], struct_t, models, device, keep_fixed=kf
    )
    # Consistency fix: reuse the keep_label helper instead of re-deriving the
    # 'fix'/'free' label inline.
    kf_label = keep_label(kf)
    label1 = f"s1:{kf_label}"
    label2 = f"s2:{kf_label}"
    label3 = f"s3:{kf_label}"
    enc1_np = batch["encoder_stage1"][0].numpy().reshape(MAP_H, MAP_W)
    enc2_np = batch["encoder_stage2"][0].numpy().reshape(MAP_H, MAP_W)
    enc3_np = batch["encoder_stage3"][0].numpy().reshape(MAP_H, MAP_W)
    inp1_np = batch["input_stage1"][0].numpy().reshape(MAP_H, MAP_W)
    rows = [
        [to_img(enc1_np), to_img(enc2_np), to_img(enc3_np)],
        [to_img(inp1_np), annotate(to_img(auto_pred1_np), label1), annotate(to_img(auto_merged12), label2), annotate(to_img(auto_merged123), label3)],
    ]
    # Assemble the 2x4 grid on a white canvas with SEP-pixel separators.
    grid = np.ones((2 * img_h + 3 * SEP, 4 * img_w + 5 * SEP, 3), dtype=np.uint8) * 255
    for r, row in enumerate(rows):
        for c, img in enumerate(row):
            y = SEP + r * (img_h + SEP)
            x = SEP + c * (img_w + SEP)
            grid[y:y + img_h, x:x + img_w] = img
    return grid
# Validation visual, part 3 (2x3 grid): row1 = reference input + generations
# with the same struct but random z; row2 = generations with random struct.
def visualize_part3(batch, models, device, tile_dict):
    SEP = 3          # separator width between grid cells, in pixels
    TILE_SIZE = 32   # rendered pixel size of one map tile
    img_h = MAP_H * TILE_SIZE
    img_w = MAP_W * TILE_SIZE
    def to_img(mat):
        # Render a tile-ID matrix with the preloaded tile sprites.
        return matrix_to_image_cv(mat, tile_dict, TILE_SIZE)
    inp1_t = batch["input_stage1"][0:1].to(device).reshape(1, MAP_SIZE)
    struct_ref = batch["struct_inject"][0:1].to(device)
    inp1_np = batch["input_stage1"][0].numpy().reshape(MAP_H, MAP_W)
    # Row 1: the reference input plus two generations that keep the sample's
    # own struct but draw a fresh random z each time.
    row1 = [to_img(inp1_np)]
    for _ in range(2):
        kf = rand_keep()
        _, _, merged123 = full_generate_random_z(inp1_t, struct_ref, models, device, keep_fixed=kf)
        row1.append(annotate(to_img(merged123), keep_label(kf)))
    # Row 2: three generations with both struct and z randomized.
    row2 = []
    for _ in range(3):
        kf = rand_keep()
        _, _, merged123 = full_generate_random_z(inp1_t, random_struct(device), models, device, keep_fixed=kf)
        row2.append(annotate(to_img(merged123), keep_label(kf)))
    rows = [row1, row2]
    # Assemble the 2x3 grid on a white canvas with SEP-pixel separators.
    grid = np.ones((2 * img_h + 3 * SEP, 3 * img_w + 4 * SEP, 3), dtype=np.uint8) * 255
    for r, row in enumerate(rows):
        for c, img in enumerate(row):
            y = SEP + r * (img_h + SEP)
            x = SEP + c * (img_w + SEP)
            grid[y:y + img_h, x:x + img_w] = img
    return grid
# Validation visual, part 4 (2x3 grid): seed the map with a few random walls,
# then generate freely with fully random struct + z (no dataset sample used).
def visualize_part4(models, device, tile_dict):
    SEP = 3          # separator width between grid cells, in pixels
    TILE_SIZE = 32   # rendered pixel size of one map tile
    img_h = MAP_H * TILE_SIZE
    img_w = MAP_W * TILE_SIZE
    def to_img(mat):
        # Render a tile-ID matrix with the preloaded tile sprites.
        return matrix_to_image_cv(mat, tile_dict, TILE_SIZE)
    # Seed: an all-MASK map with 2%-6% of positions pre-set to wall (ID 1).
    n_walls = random.randint(math.floor(MAP_SIZE * 0.02), math.floor(MAP_SIZE * 0.06))
    seed = torch.full((1, MAP_SIZE), MASK_TOKEN, dtype=torch.long, device=device)
    wall_pos = torch.randperm(MAP_SIZE, device=device)[:n_walls]
    seed[0, wall_pos] = 1
    seed_np = seed[0].cpu().numpy().reshape(MAP_H, MAP_W)
    # Five independent generations from the same seed, random struct + z each.
    results = []
    for _ in range(5):
        kf = rand_keep()
        _, _, merged123 = full_generate_random_z(seed, random_struct(device), models, device, keep_fixed=kf)
        results.append(annotate(to_img(merged123), keep_label(kf)))
    # Layout: row1 = seed + first two generations; row2 = remaining three.
    row1 = [to_img(seed_np)] + results[:2]
    row2 = results[2:]
    rows = [row1, row2]
    # Assemble the 2x3 grid on a white canvas with SEP-pixel separators.
    grid = np.ones((2 * img_h + 3 * SEP, 3 * img_w + 4 * SEP, 3), dtype=np.uint8) * 255
    for r, row in enumerate(rows):
        for c, img in enumerate(row):
            y = SEP + r * (img_h + SEP)
            x = SEP + c * (img_w + SEP)
            grid[y:y + img_h, x:x + img_w] = img
    return grid
def visualize_validate(
    batch, logits1, logits2, logits3, z_q,
    models: list[torch.nn.Module], device: torch.device, tile_dict, epoch: int, batch_idx: int
):
    """Write the three per-batch validation images (val/full/rand) to disk."""
    save_dir = f"result/seperated/e{epoch}"
    os.makedirs(save_dir, exist_ok=True)
    def _save(prefix, image):
        # One image per visualization kind, named <prefix><batch_idx>.png.
        cv2.imwrite(f"{save_dir}/{prefix}{batch_idx}.png", image)
    _save("val", visualize_part1(batch, logits1, logits2, logits3, tile_dict))
    _save("full", visualize_part2(batch, z_q, models, device, tile_dict))
    _save("rand", visualize_part3(batch, models, device, tile_dict))
def validate(dataloader: DataLoader, models: list[torch.nn.Module], device: torch.device, tile_dict, epoch: int):
    """Run one validation pass: accumulate losses and emit all visual dumps.

    Returns:
        (loss1_total, loss2_total, loss3_total, commit_total) — SUMS over all
        validation batches, not averages; the caller divides by the batch count.
    """
    vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler = models
    # Switch to inference mode (disables dropout / batch-norm statistic updates).
    for m in [vq1, vq2, vq3, mg1, mg2, mg3]:
        m.eval()
    # Per-stage loss accumulators (summed over all batches).
    loss1_total = torch.Tensor([0]).to(device)
    loss2_total = torch.Tensor([0]).to(device)
    loss3_total = torch.Tensor([0]).to(device)
    commit_total = torch.Tensor([0]).to(device)
    idx = 0
    with torch.no_grad():
        for batch in tqdm(dataloader, leave=False, desc="Validate Progress", disable=disable_tqdm):
            # Per-stage masked inputs, prediction targets and VQ encoder inputs.
            inp1 = batch["input_stage1"].to(device).reshape(-1, MAP_SIZE)
            target1 = batch["target_stage1"].to(device).reshape(-1, MAP_SIZE)
            enc1 = batch["encoder_stage1"].to(device).reshape(-1, MAP_SIZE)
            inp2 = batch["input_stage2"].to(device).reshape(-1, MAP_SIZE)
            target2 = batch["target_stage2"].to(device).reshape(-1, MAP_SIZE)
            enc2 = batch["encoder_stage2"].to(device).reshape(-1, MAP_SIZE)
            inp3 = batch["input_stage3"].to(device).reshape(-1, MAP_SIZE)
            target3 = batch["target_stage3"].to(device).reshape(-1, MAP_SIZE)
            enc3 = batch["encoder_stage3"].to(device).reshape(-1, MAP_SIZE)
            struct = batch["struct_inject"].to(device)
            # VQ encoding: encode each stage independently, then concat + quantize.
            z_e1 = vq1(enc1)  # [B, L, d_z]
            z_e2 = vq2(enc2)
            z_e3 = vq3(enc3)
            z_e_all = torch.cat([z_e1, z_e2, z_e3], dim=1)  # [B, L*3, d_z]
            z_q, _, commit_loss = quantizer(z_e_all)  # [B, L*3, d_z]
            # Three-stage MaskGIT inference (all conditioned on the full z_q and struct).
            logits1 = mg1(inp1, z_q, struct)
            logits2 = mg2(inp2, z_q, struct)
            logits3 = mg3(inp3, z_q, struct)
            loss1_total += focal_loss(logits1, target1)
            loss2_total += focal_loss(logits2, target2)
            loss3_total += focal_loss(logits3, target3)
            commit_total += commit_loss
            # Three visualization images (val/full/rand) for every batch.
            visualize_validate(batch, logits1, logits2, logits3, z_q, models, device, tile_dict, epoch, idx)
            idx += 1
    # One extra unconditional free-generation image per epoch (batch-independent).
    save_dir = f"result/seperated/e{epoch}"
    os.makedirs(save_dir, exist_ok=True)
    cv2.imwrite(f"{save_dir}/free.png", visualize_part4(models, device, tile_dict))
    # Restore training mode.
    for m in [vq1, vq2, vq3, mg1, mg2, mg3]:
        m.train()
    return loss1_total, loss2_total, loss3_total, commit_total
def train(device: torch.device):
    """Full training loop for the three-stage cascaded generator.

    Handles argument parsing, optional checkpoint resume, dataset/dataloader
    construction, per-epoch optimization, periodic validation + checkpoint
    saving, and the final model dump.
    """
    args = parse_arguments()
    models = build_model(device)
    vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler = models
    start_epoch = 0
    if args.resume:
        # Resume from the given checkpoint: restore all model weights and state.
        ckpt = torch.load(args.state, map_location=device)
        vq1.load_state_dict(ckpt["vq1"])
        vq2.load_state_dict(ckpt["vq2"])
        vq3.load_state_dict(ckpt["vq3"])
        mg1.load_state_dict(ckpt["mg1"])
        mg2.load_state_dict(ckpt["mg2"])
        mg3.load_state_dict(ckpt["mg3"])
        quantizer.load_state_dict(ckpt["quantizer"])
        # load_optim=False skips optimizer/scheduler restore (useful after
        # changing the learning rate for continued training).
        if args.load_optim and "optimizer" in ckpt:
            optimizer.load_state_dict(ckpt["optimizer"])
        if args.load_optim and "scheduler" in ckpt:
            scheduler.load_state_dict(ckpt["scheduler"])
        start_epoch = ckpt.get("epoch", 0)  # continue from the saved epoch
        tqdm.write(f"Resumed from epoch {start_epoch}: {args.state}")
    os.makedirs("result/seperated", exist_ok=True)
    dataset = GinkaSeperatedDataset(
        args.train, subset_weights=SUBSET_WEIGHTS, subset2_wall_prob=SUBSET2_WALL_PROB
    )
    dataloader = DataLoader(
        dataset, batch_size=BATCH_SIZE, shuffle=True
    )
    dataset_val = GinkaSeperatedDataset(
        args.validate, subset_weights=SUBSET_WEIGHTS, subset2_wall_prob=SUBSET2_WALL_PROB
    )
    # Fix: clamp the validation batch size to at least 1 — with fewer than 8
    # validation samples, len(dataset_val) // 8 evaluated to 0, which
    # DataLoader rejects with a ValueError.
    dataloader_val = DataLoader(
        dataset_val, batch_size=max(1, min(BATCH_SIZE, len(dataset_val) // 8)), shuffle=True
    )
    # Preload tile sprites keyed by file name (without extension); used by the
    # visualizers to map tile IDs to pixel images.
    tile_dict = {}
    for f in os.listdir("tiles"):
        name = os.path.splitext(f)[0]
        img = cv2.imread(f"tiles/{f}", cv2.IMREAD_UNCHANGED)
        if img is not None:
            tile_dict[name] = img
    for epoch in tqdm(range(start_epoch, EPOCHS), desc="Seperated Training", disable=disable_tqdm):
        # Per-epoch loss accumulators (summed over batches, averaged for logging).
        loss_total = torch.Tensor([0]).to(device)
        loss1_total = torch.Tensor([0]).to(device)
        loss2_total = torch.Tensor([0]).to(device)
        loss3_total = torch.Tensor([0]).to(device)
        commit_total = torch.Tensor([0]).to(device)
        for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
            # Per-stage masked input sequences, prediction targets and encoder contexts.
            inp1 = batch["input_stage1"].to(device).reshape(-1, MAP_SIZE)
            target1 = batch["target_stage1"].to(device).reshape(-1, MAP_SIZE)
            enc1 = batch["encoder_stage1"].to(device).reshape(-1, MAP_SIZE)
            inp2 = batch["input_stage2"].to(device).reshape(-1, MAP_SIZE)
            target2 = batch["target_stage2"].to(device).reshape(-1, MAP_SIZE)
            enc2 = batch["encoder_stage2"].to(device).reshape(-1, MAP_SIZE)
            inp3 = batch["input_stage3"].to(device).reshape(-1, MAP_SIZE)
            target3 = batch["target_stage3"].to(device).reshape(-1, MAP_SIZE)
            enc3 = batch["encoder_stage3"].to(device).reshape(-1, MAP_SIZE)
            # Structure condition vector: [cond_sym, cond_room, cond_branch, cond_outer].
            struct = batch["struct_inject"].to(device)
            optimizer.zero_grad()
            # VQ encoding: each encoder processes its own context slice.
            z_e1 = vq1(enc1)  # [B, L, d_z]
            z_e2 = vq2(enc2)
            z_e3 = vq3(enc3)
            # Concatenate the three stage encodings, then quantize.
            z_e_all = torch.cat([z_e1, z_e2, z_e3], dim=1)  # [B, L*3, d_z]
            z_q, _, commit_loss = quantizer(z_e_all)  # [B, L*3, d_z]
            # Three-stage MaskGIT forward (each receives the full three-stage z_q).
            logits1 = mg1(inp1, z_q, struct)
            logits2 = mg2(inp2, z_q, struct)
            logits3 = mg3(inp3, z_q, struct)
            # Weighted sum of the three focal losses plus the VQ commit loss.
            loss1 = focal_loss(logits1, target1)
            loss2 = focal_loss(logits2, target2)
            loss3 = focal_loss(logits3, target3)
            loss1_weighted = STAGE1_FOCAL_WEIGHT * loss1
            loss2_weighted = STAGE2_FOCAL_WEIGHT * loss2
            loss3_weighted = STAGE3_FOCAL_WEIGHT * loss3
            commit_weighted = VQ_BETA * commit_loss
            loss = loss1_weighted + loss2_weighted + loss3_weighted + commit_weighted
            loss.backward()
            optimizer.step()
            # Accumulate detached values so the graph is not retained.
            loss_total += loss.detach()
            loss1_total += loss1.detach()
            loss2_total += loss2.detach()
            loss3_total += loss3.detach()
            commit_total += commit_loss.detach()
        # Step the LR scheduler once per epoch.
        scheduler.step()
        data_length = len(dataloader)
        tqdm.write(
            f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
            f"E: {epoch + 1} | Loss: {loss_total.item() / data_length:.6f} | "
            f"L1: {loss1_total.item() / data_length:.6f} | "
            f"L2: {loss2_total.item() / data_length:.6f} | "
            f"L3: {loss3_total.item() / data_length:.6f} | "
            f"VQ: {commit_total.item() / data_length:.6f} | "
            f"LR: {scheduler.get_last_lr()[0]:.6f}"
        )
        # Every CHECKPOINT epochs: run validation, log it, and save a checkpoint.
        if (epoch + 1) % CHECKPOINT == 0:
            losses = validate(dataloader_val, models, device, tile_dict, epoch + 1)
            loss1_total, loss2_total, loss3_total, commit_total = losses
            loss1_weighted = STAGE1_FOCAL_WEIGHT * loss1_total
            loss2_weighted = STAGE2_FOCAL_WEIGHT * loss2_total
            loss3_weighted = STAGE3_FOCAL_WEIGHT * loss3_total
            commit_weighted = VQ_BETA * commit_total
            loss_total = loss1_weighted + loss2_weighted + loss3_weighted + commit_weighted
            data_length = len(dataloader_val)
            tqdm.write(
                f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
                f"E: {epoch + 1} | Loss: {loss_total.item() / data_length:.6f} | "
                f"L1: {loss1_total.item() / data_length:.6f} | "
                f"L2: {loss2_total.item() / data_length:.6f} | "
                f"L3: {loss3_total.item() / data_length:.6f} | "
                f"VQ: {commit_total.item() / data_length:.6f} | "
            )
            ckpt_path = f"result/seperated/sep-{epoch + 1}.pth"
            torch.save({
                "epoch": epoch + 1,
                "vq1": vq1.state_dict(),
                "vq2": vq2.state_dict(),
                "vq3": vq3.state_dict(),
                "mg1": mg1.state_dict(),
                "mg2": mg2.state_dict(),
                "mg3": mg3.state_dict(),
                "quantizer": quantizer.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
            }, ckpt_path)
            tqdm.write(f"Saved checkpoint: {ckpt_path}")
    # After training: save the final full weights (including optimizer state,
    # usable for later resume or inference).
    final_path = "result/seperated.pth"
    torch.save({
        "epoch": EPOCHS,
        "vq1": vq1.state_dict(),
        "vq2": vq2.state_dict(),
        "vq3": vq3.state_dict(),
        "mg1": mg1.state_dict(),
        "mg2": mg2.state_dict(),
        "mg3": mg3.state_dict(),
        "quantizer": quantizer.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
    }, final_path)
    tqdm.write(f"Training complete. Final model saved: {final_path}")
if __name__ == "__main__":
    # Script entry point: train on the module-level auto-selected device.
    train(device)