feat: 调整码字容量与策略

This commit is contained in:
unanmed 2026-05-20 14:46:02 +08:00
parent 306d585a28
commit 416aa4dd72
3 changed files with 233 additions and 88 deletions

View File

@ -28,13 +28,13 @@ class GinkaMaskGIT(nn.Module):
self.outer_embed = nn.Embedding(OUTER_VOCAB, d_z)
# 剩余密度投影:将 5 个浮点数 [wall, door, monster, entrance, resource] 逐个投影为 d_z 维 token(共 5 个 token)
self.remain_proj = nn.Linear(5, d_z)
self.remain_proj = nn.Linear(1, d_z)
# z 投影:逐 token 线性变换,保持序列结构
self.z_proj = nn.Linear(d_z, d_z)
# 条件融合投影:z_seq_len 个 z token + 2 个结构 token + 5 个剩余密度 token
self.cond_proj = nn.Linear((z_seq_len + 3) * d_z, d_model)
self.cond_proj = nn.Linear((z_seq_len + 2 + 5) * d_z, d_model)
# 纯 encoder Transformer条件向量 c 通过 AdaLN 注入每一层
self.transformer = Transformer(
@ -62,7 +62,7 @@ class GinkaMaskGIT(nn.Module):
], dim=1)
# 剩余密度:每个浮点分量各自投影为一个 d_z 维 token,[B, 5, d_z]
e_remain = self.remain_proj(remain).unsqueeze(1)
e_remain = self.remain_proj(remain.unsqueeze(-1))
# z逐 token 投影,保留序列结构 [B, z_seq_len, d_z]
z_proj = self.z_proj(z)

View File

@ -23,8 +23,8 @@ from shared.image import matrix_to_image_cv
#
# 整体架构:
# VQ-VAE三组独立编码器 vq1/vq2/vq3将三阶段地图上下文分别编码为离散潜变量
# 再由共用 VectorQuantizer 统一量化为 z_q
# 三个独立 MaskGITmg1/mg2/mg3分别以 z_q 和 struct_inject 为条件,
# 再由三个独立 VectorQuantizer 分别量化为 z_q1/z_q2/z_q3
# 三个独立 MaskGITmg1/mg2/mg3分别以各自阶段 z_q 和 struct_inject 为条件,
# 逐阶段迭代解码地图图块序列。
#
# 三阶段生成目标:
@ -37,14 +37,14 @@ from shared.image import matrix_to_image_cv
# 共用 VQ-VAE 超参
# 三组编码器vq1/vq2/vq3共享相同超参分别对三阶段地图上下文独立编码
VQ_L = 2 # 码字序列长度(每个编码器输出 L 个码字,量化后合并为 L*3
VQ_K = 8 # codebook 大小(离散码本条目数)
VQ_L = 16 # 码字序列长度(每个编码器输出 L 个码字,量化后合并为 L*3
VQ_K = 16 # codebook 大小(离散码本条目数)
VQ_D_Z = 64 # 码字维度
VQ_BETA = 0.5 # commit loss 权重(防止编码器输出漂离 codebook
VQ_BETA = 1.0 # commit loss 权重(防止编码器输出漂离 codebook
VQ_GAMMA = 0.0 # entropy loss 权重(当前未启用)
VQ_LAYERS = 3 # VQ-VAE Transformer 层数
VQ_DIM_FF = 512 # VQ-VAE 前馈网络隐层维度
VQ_D_MODEL = 128 # VQ-VAE Transformer 模型维度
VQ_LAYERS = 6 # VQ-VAE Transformer 层数
VQ_DIM_FF = 1024 # VQ-VAE 前馈网络隐层维度
VQ_D_MODEL = 256 # VQ-VAE Transformer 模型维度
VQ_NHEAD = 4 # VQ-VAE 多头注意力头数
# 第一阶段 MaskGIT 超参
@ -65,10 +65,10 @@ STAGE3_MG_NHEAD = 4
STAGE3_MG_NUM_LAYERS = 6
STAGE3_MG_DIM_FF = 1024
# 三阶段 Focal Loss 损失权重(可调节各阶段对总损失的贡献比例)
STAGE1_FOCAL_WEIGHT = 1.0
STAGE2_FOCAL_WEIGHT = 1.0
STAGE3_FOCAL_WEIGHT = 1.0
# 三阶段 Cross Entropy 损失权重(可调节各阶段对总损失的贡献比例)
STAGE1_CE_WEIGHT = 1.0
STAGE2_CE_WEIGHT = 1.0
STAGE3_CE_WEIGHT = 1.0
# 各阶段 VQ commit loss 权重(当前未单独使用,统一由 VQ_BETA 控制)
STAGE1_VQ_WEIGHT = 0.5
@ -95,7 +95,6 @@ MG_Z_DROPOUT = 0.1 # z 隐变量 Dropout 概率
MG_STRUCT_DROPOUT = 0.1 # 结构参量 Dropout 概率
# 损失参数
FOCAL_GAMMA = 2.0 # Focal Loss 参数
VQ_BETA = 0.5 # 承诺损失权重
# 训练超参
@ -105,6 +104,7 @@ MIN_LR = 1e-6 # 余弦退火最低学习率
WEIGHT_DECAY = 1e-4 # L2 正则化系数
EPOCHS = 400 # 总训练轮数
CHECKPOINT = 20 # 每隔多少 epoch 保存检查点并执行验证
REFERENCE_SAMPLE_PROB = 0.2 # 训练时将参考掩码图无梯度自采样 1-3 步的概率
device = torch.device(
"cuda:0" if torch.cuda.is_available()
@ -131,7 +131,7 @@ def parse_arguments():
def build_model(device: torch.device):
# 三组 VQ-VAE 编码器各自独立编码一个阶段的地图上下文encoder_stage1/2/3
# 输出形状均为 [B, L, d_z]拼接后送入共用 quantizer
# 输出形状均为 [B, L, d_z]分别送入各自阶段的 quantizer
vq_kwargs = dict(
num_classes=NUM_CLASSES, L=VQ_L, K=VQ_K, d_model=VQ_D_MODEL,
nhead=VQ_NHEAD, num_layers=VQ_LAYERS, dim_ff=VQ_DIM_FF, map_h=MAP_H, map_w=MAP_W
@ -140,21 +140,21 @@ def build_model(device: torch.device):
vq2 = GinkaVQVAE(**vq_kwargs).to(device) # 编码 stage2 上下文door/monster/entrance
vq3 = GinkaVQVAE(**vq_kwargs).to(device) # 编码 stage3 上下文resource
# 三个独立 MaskGIT 解码器,均接收完整的三阶段 z_q 作为条件
# 三个独立 MaskGIT 解码器,分别接收各自阶段的 z_q 作为条件
mg1 = GinkaMaskGIT(
num_classes=NUM_CLASSES, d_model=STAGE1_MG_DMODEL, d_z=VQ_D_Z, dim_ff=STAGE1_MG_DIM_FF,
nhead=STAGE1_MG_NHEAD, num_layers=STAGE1_MG_NUM_LAYERS, map_h=MAP_H, map_w=MAP_W,
z_seq_len=VQ_L * 3
z_seq_len=VQ_L
).to(device)
mg2 = GinkaMaskGIT(
num_classes=NUM_CLASSES, d_model=STAGE2_MG_DMODEL, d_z=VQ_D_Z, dim_ff=STAGE2_MG_DIM_FF,
nhead=STAGE2_MG_NHEAD, num_layers=STAGE2_MG_NUM_LAYERS, map_h=MAP_H, map_w=MAP_W,
z_seq_len=VQ_L * 3
z_seq_len=VQ_L
).to(device)
mg3 = GinkaMaskGIT(
num_classes=NUM_CLASSES, d_model=STAGE3_MG_DMODEL, d_z=VQ_D_Z, dim_ff=STAGE3_MG_DIM_FF,
nhead=STAGE3_MG_NHEAD, num_layers=STAGE3_MG_NUM_LAYERS, map_h=MAP_H, map_w=MAP_W,
z_seq_len=VQ_L * 3
z_seq_len=VQ_L
).to(device)
# 六个模型参数合并到同一优化器,端到端联合训练
@ -166,18 +166,106 @@ def build_model(device: torch.device):
# 余弦退火:从 LR 线性衰减至 MIN_LR周期为全部训练轮数
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=MIN_LR)
# 共用 VectorQuantizer不参与梯度更新仅在前向时做码本查表
quantizer = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z).to(device)
# 三个独立 VectorQuantizer各阶段使用自己的码本避免语义空间相互干扰
quantizer1 = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z).to(device)
quantizer2 = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z).to(device)
quantizer3 = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z).to(device)
quantizers = (quantizer1, quantizer2, quantizer3)
return vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler
return vq1, vq2, vq3, mg1, mg2, mg3, quantizers, optimizer, scheduler
def focal_loss(logits, target):
    """Return the mean focal loss of per-token class predictions.

    Args:
        logits: Unnormalised class scores laid out as [B, L, C].
        target: Integer class ids laid out as [B, L].

    Returns:
        A scalar tensor: cross entropy per token, re-weighted by
        ``(1 - pt) ** FOCAL_GAMMA`` and averaged over all tokens.
    """
    # F.cross_entropy expects the class axis second, i.e. [B, C, L].
    ce = F.cross_entropy(logits.permute(0, 2, 1), target, reduction='none')
    # pt is the probability the model assigned to the correct class.
    pt = torch.exp(-ce)
    # Down-weight confidently correct tokens so hard tokens dominate.
    focal = ((1 - pt) ** FOCAL_GAMMA) * ce
    return focal.mean()


def cross_entropy_loss(logits, target):
    """Return the mean cross-entropy loss of per-token class predictions.

    Args:
        logits: Unnormalised class scores laid out as [B, L, C].
        target: Integer class ids laid out as [B, L].

    Returns:
        A scalar tensor averaged over every token in the batch.
    """
    # F.cross_entropy expects the class axis second, i.e. [B, C, L].
    return F.cross_entropy(logits.permute(0, 2, 1), target)
def summarize_codebook_hits(code_hits: torch.Tensor) -> tuple[float, float, int]:
    """Summarise accumulated codebook hit counts.

    Args:
        code_hits: Tensor of hit counts; every element is treated as one
            bucket of a single distribution.

    Returns:
        ``(perplexity, usage_rate, active_codes)`` where perplexity is the
        exponential of the entropy of the normalised counts, usage_rate is
        the fraction of buckets hit at least once, and active_codes is the
        number of such buckets. All zeros when nothing was hit.
    """
    hit_sum = code_hits.sum()
    if hit_sum.item() <= 0:
        # Nothing was counted, so there is no distribution to describe.
        return 0.0, 0.0, 0

    distribution = code_hits / hit_sum
    # Clamp before the log so empty buckets contribute 0 instead of NaN.
    log_distribution = torch.log(distribution.clamp_min(1e-10))
    neg_entropy = (distribution * log_distribution).sum()

    n_active = int((code_hits > 0).sum().item())
    return torch.exp(-neg_entropy).item(), n_active / code_hits.numel(), n_active
def quantize_stage_latents(
    quantizers: tuple[VectorQuantizer, VectorQuantizer, VectorQuantizer],
    z_e1: torch.Tensor,
    z_e2: torch.Tensor,
    z_e3: torch.Tensor
) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
    """Quantize each stage's encoder output with that stage's own quantizer.

    Args:
        quantizers: Exactly three quantizers, one per stage, in stage order.
        z_e1: Encoder output of stage 1.
        z_e2: Encoder output of stage 2.
        z_e3: Encoder output of stage 3.

    Returns:
        ``((z_q1, z_q2, z_q3), commit_loss, code_hits)``: the three quantized
        latents, the mean of the three commit losses, and the per-stage hit
        counts stacked along a new leading axis.
    """
    # Unpacking enforces that exactly three quantizers were supplied.
    stage1_q, stage2_q, stage3_q = quantizers

    latents = []
    commits = []
    hits = []
    for quantizer, z_e in ((stage1_q, z_e1), (stage2_q, z_e2), (stage3_q, z_e3)):
        # The quantizer returns five values; only the quantized latent,
        # the commit loss and the per-code hit counts are used here.
        z_q, _, commit, _, stage_hits = quantizer(z_e)
        latents.append(z_q)
        commits.append(commit)
        hits.append(stage_hits)

    commit_loss = (commits[0] + commits[1] + commits[2]) / 3
    code_hits = torch.stack(hits, dim=0)
    return (latents[0], latents[1], latents[2]), commit_loss, code_hits
def build_reference_rollout_steps(prob: float) -> int:
    """Decide how many self-sampling rollout steps to run.

    Args:
        prob: Probability of performing a rollout at all.

    Returns:
        0 when the rollout is skipped, otherwise a uniformly drawn integer
        in the inclusive range 1..3.
    """
    skip_rollout = random.random() >= prob
    return 0 if skip_rollout else random.randint(1, 3)
def sample_reference_inputs(
    model: torch.nn.Module,
    reference: torch.Tensor,
    z_q: torch.Tensor,
    struct: torch.Tensor,
    target_density: torch.Tensor,
    stage: int,
    rollout_steps: int
) -> torch.Tensor:
    """Partially fill a masked reference map with the model's own samples.

    For up to ``rollout_steps`` iterations the model predicts every position,
    a token is sampled per position, and for each batch row the 10% (at
    least one) of still-masked positions with the highest sampled-token
    probability are replaced by their sampled token. No gradients are
    recorded and ``reference`` itself is never modified.

    Args:
        model: Callable as ``model(tokens, z_q, struct, remain)`` returning
            per-position class logits with the class axis last.
        reference: Token ids indexed as [batch, position]; masked positions
            hold ``MASK_TOKEN``.
        z_q: Conditioning latent; detached before being fed to the model.
        struct: Structure conditioning passed through to the model.
        target_density: Passed to ``compute_remaining`` together with the
            current tokens and ``stage``.
        stage: Stage index forwarded to ``compute_remaining``.
        rollout_steps: Maximum number of reveal iterations; ``<= 0`` returns
            ``reference`` unchanged (same object).

    Returns:
        ``reference`` itself when no rollout is requested, otherwise a new
        tensor with some masked positions filled in.
    """
    if rollout_steps <= 0:
        return reference

    with torch.no_grad():
        # A single private copy is enough; the caller's tensor stays intact.
        current = reference.clone()
        z_q_detached = z_q.detach()

        for _ in range(rollout_steps):
            masked_positions = current == MASK_TOKEN
            # One host transfer per step instead of one `.item()` per row.
            masked_counts = masked_positions.sum(dim=1).tolist()
            if sum(masked_counts) <= 0:
                break

            remain = compute_remaining(current, target_density, stage)
            logits = model(current, z_q_detached, struct, remain)
            probs = F.softmax(logits, dim=-1)
            predicted = torch.distributions.Categorical(probs).sample()
            # Probability the model assigned to the token it actually sampled.
            confidence = torch.gather(
                probs,
                -1,
                predicted.unsqueeze(-1)
            ).squeeze(-1)

            for local_idx, masked_count in enumerate(masked_counts):
                if masked_count <= 0:
                    continue

                masked_indices = masked_positions[local_idx].nonzero(as_tuple=True)[0]
                # Reveal 10% of what is still masked, but always at least one.
                reveal_count = max(1, math.ceil(masked_count * 0.1))
                reveal_count = min(reveal_count, masked_indices.numel())

                _, reveal_order = torch.topk(
                    confidence[local_idx, masked_indices],
                    k=reveal_count,
                    largest=True
                )
                reveal_indices = masked_indices[reveal_order]
                current[local_idx, reveal_indices] = predicted[local_idx, reveal_indices]

    return current
def random_struct(device: torch.device) -> torch.Tensor:
# 随机采样一组结构参量,用于无条件自由生成
@ -336,14 +424,17 @@ def full_generate_random_z(
device: torch.device,
keep_fixed: tuple[bool, bool, bool] = (True, True, True)
) -> tuple:
vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler = models
vq1, vq2, vq3, mg1, mg2, mg3, quantizers, optimizer, scheduler = models
quantizer1, quantizer2, quantizer3 = quantizers
with torch.no_grad():
z = quantizer.sample(1, VQ_L, device)
z1 = quantizer1.sample(1, VQ_L, device)
z2 = quantizer2.sample(1, VQ_L, device)
z3 = quantizer3.sample(1, VQ_L, device)
# stage1生成墙壁骨架
pred1_np = maskgit_sample(
mg1, input.clone(), z, struct, target_density, 1,
mg1, input.clone(), z1, struct, target_density, 1,
GENERATE_STEP, target_tiles=[1], keep_fixed=keep_fixed[0]
)
inp2 = torch.tensor(pred1_np.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE)
@ -351,7 +442,7 @@ def full_generate_random_z(
# stage2在骨架上生成 door(2)/monster(4)/entrance(5),非零结果覆盖合并
pred2_np = maskgit_sample(
mg2, inp2, z, struct, target_density, 2,
mg2, inp2, z2, struct, target_density, 2,
GENERATE_STEP, target_tiles=[2, 4, 5], keep_fixed=keep_fixed[1]
)
merged12 = pred1_np.copy()
@ -361,7 +452,7 @@ def full_generate_random_z(
# stage3填充 resource(3)
pred3_np = maskgit_sample(
mg3, inp3, z, struct, target_density, 3,
mg3, inp3, z3, struct, target_density, 3,
GENERATE_STEP, target_tiles=[3], keep_fixed=keep_fixed[2]
)
merged123 = merged12.copy()
@ -371,26 +462,28 @@ def full_generate_random_z(
def full_generate_specific_z(
input: torch.Tensor,
z: torch.Tensor,
z_q: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
struct: torch.Tensor,
target_density: torch.Tensor,
models: list[torch.nn.Module],
device: torch.device,
keep_fixed: tuple[bool, bool, bool] = (True, True, True)
) -> tuple:
vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler = models
vq1, vq2, vq3, mg1, mg2, mg3, quantizers, optimizer, scheduler = models
with torch.no_grad():
z1, z2, z3 = z_q
# 与 full_generate_random_z 相同的三阶段级联,但使用给定的 z
pred1_np = maskgit_sample(
mg1, input.clone(), z, struct, target_density, 1,
mg1, input.clone(), z1, struct, target_density, 1,
GENERATE_STEP, target_tiles=[1], keep_fixed=keep_fixed[0]
)
inp2 = torch.tensor(pred1_np.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE)
inp2[inp2 == 0] = MASK_TOKEN
pred2_np = maskgit_sample(
mg2, inp2, z, struct, target_density, 2,
mg2, inp2, z2, struct, target_density, 2,
GENERATE_STEP, target_tiles=[2, 4, 5], keep_fixed=keep_fixed[1]
)
merged12 = pred1_np.copy()
@ -399,7 +492,7 @@ def full_generate_specific_z(
inp3[inp3 == 0] = MASK_TOKEN
pred3_np = maskgit_sample(
mg3, inp3, z, struct, target_density, 3,
mg3, inp3, z3, struct, target_density, 3,
GENERATE_STEP, target_tiles=[3], keep_fixed=keep_fixed[2]
)
merged123 = merged12.copy()
@ -493,9 +586,10 @@ def visualize_part2(batch, z_q, models, device, tile_dict):
inp1_t = batch["input_stage1"][0:1].to(device).reshape(1, MAP_SIZE)
struct_t = batch["struct_inject"][0:1].to(device)
target_density_t = batch["target_density"][0:1].to(device)
z_q_single = (z_q[0][0:1], z_q[1][0:1], z_q[2][0:1])
kf = rand_keep()
auto_pred1_np, auto_merged12, auto_merged123 = full_generate_specific_z(
inp1_t, z_q[0:1], struct_t, target_density_t, models, device, keep_fixed=kf
inp1_t, z_q_single, struct_t, target_density_t, models, device, keep_fixed=kf
)
kf_label = 'fix' if kf[0] else 'free'
@ -665,6 +759,7 @@ def visualize_density_var(batch, z_q, models, device, tile_dict):
struct_t = batch["struct_inject"][0:1].to(device)
struct_cpu = batch["struct_inject"][0]
base_target_density = batch["target_density"][0:1].to(device)
z_q_single = (z_q[0][0:1], z_q[1][0:1], z_q[2][0:1])
ref_np = batch["encoder_stage3"][0].numpy().reshape(MAP_H, MAP_W)
gen_imgs = []
wall_count_values = [20, 35, 50, 65, 80]
@ -673,7 +768,7 @@ def visualize_density_var(batch, z_q, models, device, tile_dict):
fixed_target_density[0, WALL_DENSITY_IDX] = wall_count / MAP_SIZE
target_density_cpu = fixed_target_density[0].cpu()
_, _, merged123 = full_generate_specific_z(
inp1_t, z_q[0:1], struct_t, fixed_target_density, models, device
inp1_t, z_q_single, struct_t, fixed_target_density, models, device
)
gen_imgs.append(annotate_labels(to_img(merged123), struct_cpu, target_density_cpu))
row1 = [to_img(ref_np)] + gen_imgs[:2]
@ -695,7 +790,8 @@ def validate(
density_stats: dict,
epoch: int
):
vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler = models
vq1, vq2, vq3, mg1, mg2, mg3, quantizers, optimizer, scheduler = models
quantizer1, quantizer2, quantizer3 = quantizers
# 切换为推理模式(关闭 Dropout / BatchNorm 统计更新)
for m in [vq1, vq2, vq3, mg1, mg2, mg3]:
@ -706,6 +802,7 @@ def validate(
loss2_total = torch.Tensor([0]).to(device)
loss3_total = torch.Tensor([0]).to(device)
commit_total = torch.Tensor([0]).to(device)
code_hits_total = torch.zeros(3, quantizer1.K, device=device)
density_metrics = {
1: {"mae": 0.0, "over": 0.0, "count": 0},
@ -736,27 +833,30 @@ def validate(
struct = batch["struct_inject"].to(device)
target_density = batch["target_density"].to(device)
# VQ 编码:各阶段独立编码后拼接、量化
# VQ 编码:各阶段独立编码并分别量化
z_e1 = vq1(enc1) # [B, L, d_z]
z_e2 = vq2(enc2)
z_e3 = vq3(enc3)
z_e_all = torch.cat([z_e1, z_e2, z_e3], dim=1) # [B, L*3, d_z]
z_q, _, commit_loss = quantizer(z_e_all) # [B, L*3, d_z]
z_q, commit_loss, code_hits = quantize_stage_latents(
quantizers, z_e1, z_e2, z_e3
)
z_q1, z_q2, z_q3 = z_q
remain1 = compute_remaining(inp1, target_density, 1)
remain2 = compute_remaining(inp2, target_density, 2)
remain3 = compute_remaining(inp3, target_density, 3)
# 三阶段 MaskGIT 推理(均以完整 z_q、struct 和动态 remain 为条件)
logits1 = mg1(inp1, z_q, struct, remain1)
logits2 = mg2(inp2, z_q, struct, remain2)
logits3 = mg3(inp3, z_q, struct, remain3)
# 三阶段 MaskGIT 推理:各阶段只接收自己的 z_q
logits1 = mg1(inp1, z_q1, struct, remain1)
logits2 = mg2(inp2, z_q2, struct, remain2)
logits3 = mg3(inp3, z_q3, struct, remain3)
loss1_total += focal_loss(logits1, target1)
loss2_total += focal_loss(logits2, target2)
loss3_total += focal_loss(logits3, target3)
loss1_total += cross_entropy_loss(logits1, target1)
loss2_total += cross_entropy_loss(logits2, target2)
loss3_total += cross_entropy_loss(logits3, target3)
commit_total += commit_loss
code_hits_total += code_hits
# 计算各目标对象的真实密度误差与过量生成密度
pred1_map = torch.argmax(logits1, dim=-1).cpu()
@ -806,19 +906,20 @@ def validate(
for m in [vq1, vq2, vq3, mg1, mg2, mg3]:
m.train()
return loss1_total, loss2_total, loss3_total, commit_total
return loss1_total, loss2_total, loss3_total, commit_total, code_hits_total
def train(device: torch.device):
args = parse_arguments()
models = build_model(device)
vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler = models
vq1, vq2, vq3, mg1, mg2, mg3, quantizers, optimizer, scheduler = models
quantizer1, quantizer2, quantizer3 = quantizers
tqdm.write(f"Device: {device}")
model_list = [
("vq1", vq1), ("vq2", vq2), ("vq3", vq3),
("mg1", mg1), ("mg2", mg2), ("mg3", mg3),
("quantizer", quantizer)
("quantizer1", quantizer1), ("quantizer2", quantizer2), ("quantizer3", quantizer3)
]
total_params = 0
for name, m in model_list:
@ -838,7 +939,15 @@ def train(device: torch.device):
mg1.load_state_dict(ckpt["mg1"])
mg2.load_state_dict(ckpt["mg2"])
mg3.load_state_dict(ckpt["mg3"])
quantizer.load_state_dict(ckpt["quantizer"])
if "quantizer1" in ckpt:
quantizer1.load_state_dict(ckpt["quantizer1"])
quantizer2.load_state_dict(ckpt["quantizer2"])
quantizer3.load_state_dict(ckpt["quantizer3"])
elif "quantizer" in ckpt:
quantizer1.load_state_dict(ckpt["quantizer"])
quantizer2.load_state_dict(ckpt["quantizer"])
quantizer3.load_state_dict(ckpt["quantizer"])
tqdm.write("Loaded legacy shared quantizer weights into quantizer1/2/3")
# load_optim=False 时可跳过优化器/调度器恢复(适合调整学习率后继续训练)
if args.load_optim and "optimizer" in ckpt:
optimizer.load_state_dict(ckpt["optimizer"])
@ -878,6 +987,7 @@ def train(device: torch.device):
loss2_total = torch.Tensor([0]).to(device)
loss3_total = torch.Tensor([0]).to(device)
commit_total = torch.Tensor([0]).to(device)
code_hits_total = torch.zeros(3, quantizer1.K, device=device)
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
# 三阶段各自的掩码输入序列、预测目标和编码器上下文
@ -904,26 +1014,39 @@ def train(device: torch.device):
z_e2 = vq2(enc2)
z_e3 = vq3(enc3)
# 合并三阶段编码后量化
z_e_all = torch.cat([z_e1, z_e2, z_e3], dim=1) # [B, L*3, d_z]
z_q, _, commit_loss = quantizer(z_e_all) # [B, L*3, d_z]
# 三阶段分别量化,各自使用独立 codebook
z_q, commit_loss, code_hits = quantize_stage_latents(
quantizers, z_e1, z_e2, z_e3
)
z_q1, z_q2, z_q3 = z_q
rollout_steps = build_reference_rollout_steps(REFERENCE_SAMPLE_PROB)
inp1 = sample_reference_inputs(
mg1, inp1, z_q1, struct, target_density, 1, rollout_steps
)
inp2 = sample_reference_inputs(
mg2, inp2, z_q2, struct, target_density, 2, rollout_steps
)
inp3 = sample_reference_inputs(
mg3, inp3, z_q3, struct, target_density, 3, rollout_steps
)
remain1 = compute_remaining(inp1, target_density, 1)
remain2 = compute_remaining(inp2, target_density, 2)
remain3 = compute_remaining(inp3, target_density, 3)
# 三阶段 MaskGIT 前向(均接收完整三阶段 z_q、struct 和动态 remain 条件)
logits1 = mg1(inp1, z_q, struct, remain1)
logits2 = mg2(inp2, z_q, struct, remain2)
logits3 = mg3(inp3, z_q, struct, remain3)
# 三阶段 MaskGIT 前向:各阶段只接收自己的 z_q、struct 和动态 remain 条件
logits1 = mg1(inp1, z_q1, struct, remain1)
logits2 = mg2(inp2, z_q2, struct, remain2)
logits3 = mg3(inp3, z_q3, struct, remain3)
# 三阶段 Focal Loss + VQ commit loss 加权求和
loss1 = focal_loss(logits1, target1)
loss2 = focal_loss(logits2, target2)
loss3 = focal_loss(logits3, target3)
loss1_weighted = STAGE1_FOCAL_WEIGHT * loss1
loss2_weighted = STAGE2_FOCAL_WEIGHT * loss2
loss3_weighted = STAGE3_FOCAL_WEIGHT * loss3
# 三阶段 Cross Entropy + VQ commit loss 加权求和
loss1 = cross_entropy_loss(logits1, target1)
loss2 = cross_entropy_loss(logits2, target2)
loss3 = cross_entropy_loss(logits3, target3)
loss1_weighted = STAGE1_CE_WEIGHT * loss1
loss2_weighted = STAGE2_CE_WEIGHT * loss2
loss3_weighted = STAGE3_CE_WEIGHT * loss3
commit_weighted = VQ_BETA * commit_loss
loss = loss1_weighted + loss2_weighted + loss3_weighted + commit_weighted
@ -936,11 +1059,13 @@ def train(device: torch.device):
loss2_total += loss2.detach()
loss3_total += loss3.detach()
commit_total += commit_loss.detach()
code_hits_total += code_hits.detach()
# 每个 epoch 结束后更新学习率
scheduler.step()
data_length = len(dataloader)
train_perplexity, train_usage_rate, train_active_codes = summarize_codebook_hits(code_hits_total)
tqdm.write(
f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
f"E: {epoch + 1} | Loss: {loss_total.item() / data_length:.6f} | "
@ -948,20 +1073,23 @@ def train(device: torch.device):
f"L2: {loss2_total.item() / data_length:.6f} | "
f"L3: {loss3_total.item() / data_length:.6f} | "
f"VQ: {commit_total.item() / data_length:.6f} | "
f"PPL: {train_perplexity:.4f} | "
f"Usage: {train_usage_rate:.4f} ({train_active_codes}/{code_hits_total.numel()}) | "
f"LR: {scheduler.get_last_lr()[0]:.6f}"
)
# 每 CHECKPOINT 个 epoch 执行一次验证、可视化和检查点保存
if (epoch + 1) % CHECKPOINT == 0:
losses = validate(dataloader_val, models, device, tile_dict, dataset.density_stats, epoch + 1)
loss1_total, loss2_total, loss3_total, commit_total = losses
loss1_weighted = STAGE1_FOCAL_WEIGHT * loss1_total
loss2_weighted = STAGE2_FOCAL_WEIGHT * loss2_total
loss3_weighted = STAGE3_FOCAL_WEIGHT * loss3_total
loss1_total, loss2_total, loss3_total, commit_total, code_hits_total = losses
loss1_weighted = STAGE1_CE_WEIGHT * loss1_total
loss2_weighted = STAGE2_CE_WEIGHT * loss2_total
loss3_weighted = STAGE3_CE_WEIGHT * loss3_total
commit_weighted = VQ_BETA * commit_total
loss_total = loss1_weighted + loss2_weighted + loss3_weighted + commit_weighted
data_length = len(dataloader_val)
val_perplexity, val_usage_rate, val_active_codes = summarize_codebook_hits(code_hits_total)
tqdm.write(
f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
f"E: {epoch + 1} | Loss: {loss_total.item() / data_length:.6f} | "
@ -969,6 +1097,8 @@ def train(device: torch.device):
f"L2: {loss2_total.item() / data_length:.6f} | "
f"L3: {loss3_total.item() / data_length:.6f} | "
f"VQ: {commit_total.item() / data_length:.6f} | "
f"PPL: {val_perplexity:.4f} | "
f"Usage: {val_usage_rate:.4f} ({val_active_codes}/{code_hits_total.numel()}) | "
)
ckpt_path = f"result/seperated/sep-{epoch + 1}.pth"
@ -980,7 +1110,9 @@ def train(device: torch.device):
"mg1": mg1.state_dict(),
"mg2": mg2.state_dict(),
"mg3": mg3.state_dict(),
"quantizer": quantizer.state_dict(),
"quantizer1": quantizer1.state_dict(),
"quantizer2": quantizer2.state_dict(),
"quantizer3": quantizer3.state_dict(),
"optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict(),
}, ckpt_path)
@ -996,7 +1128,9 @@ def train(device: torch.device):
"mg1": mg1.state_dict(),
"mg2": mg2.state_dict(),
"mg3": mg3.state_dict(),
"quantizer": quantizer.state_dict(),
"quantizer1": quantizer1.state_dict(),
"quantizer2": quantizer2.state_dict(),
"quantizer3": quantizer3.state_dict(),
"optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict(),
}, final_path)

View File

@ -18,8 +18,23 @@ class VectorQuantizer(nn.Module):
self.codebook = nn.Embedding(K, d_z)
nn.init.uniform_(self.codebook.weight, -1.0 / K, 1.0 / K)
def forward(self, z_e: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# z_e: [B, L * 3, d_z]
def codebook_stats(
    self, indices: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute usage statistics for a set of selected codebook indices.

    Args:
        indices: Integer code indices of any shape; flattened before use.

    Returns:
        ``(perplexity, usage_rate, usage_count)``: the exponential of the
        entropy of the empirical code distribution, the fraction of the
        ``self.K`` codes selected at least once, and the per-code hit
        counts as a float tensor of length ``self.K``.
    """
    assignments = F.one_hot(indices.reshape(-1), num_classes=self.K).float()
    usage_count = assignments.sum(dim=0)
    avg_probs = assignments.mean(dim=0)

    # Clamp before the log so unused codes contribute 0 instead of NaN.
    neg_entropy = (avg_probs * torch.log(avg_probs.clamp_min(1e-10))).sum()
    perplexity = torch.exp(-neg_entropy)

    usage_rate = (avg_probs > 0).float().mean()
    return perplexity, usage_rate, usage_count
def forward(
self, z_e: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# z_e: [B, L, d_z]
"""
Args:
z_e: [B, L, d_z] 编码器输出的连续向量序列
@ -31,7 +46,7 @@ class VectorQuantizer(nn.Module):
"""
B, L, d_z = z_e.shape
z_flat = z_e.reshape(B * L, d_z) # [B * L * 3, d_z]
z_flat = z_e.reshape(B * L, d_z) # [B * L, d_z]
codebook_w = self.codebook.weight # [K, d_z]
@ -56,13 +71,9 @@ class VectorQuantizer(nn.Module):
commit_loss = F.mse_loss(z_e, z_q.detach())
indices = indices.reshape(B, L)
return z_q_st, indices, commit_loss
perplexity, usage_rate, usage_count = self.codebook_stats(indices)
return z_q_st, indices, commit_loss, perplexity, usage_count
def sample(self, B: int, L: int, device: torch.device) -> torch.Tensor:
    """Draw uniformly random code vectors from this quantizer's codebook.

    Each quantizer serves a single stage, so exactly ``L`` codes are drawn
    per batch element rather than three concatenated groups.

    Args:
        B: Batch size.
        L: Number of code vectors per batch element.
        device: Device on which the random indices are created.

    Returns:
        Tensor of shape [B, L, d_z] holding the looked-up codebook rows.
    """
    indices = torch.randint(0, self.K, (B, L), device=device)
    return self.codebook(indices)