fix: MaskGIT 输出地图不更新编码器

Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
unanmed 2026-04-28 15:53:41 +08:00
parent 6b2b36f20c
commit 789107969b

View File

@ -625,14 +625,21 @@ def train():
# 4. z 一致性约束(方案 A将 MaskGIT 的 logits 经温度平滑后
# 与 VQ 编码器的 tile embedding 做加权求和,得到软嵌入序列,
# 再送入编码器得到 z_pred_e约束其与真实 z_e 对齐。
# 梯度从 z_pred_e 回传到 MaskGIT 的 logits以及 VQ encoder 的权重);
# z_e 作为 detach 后的监督目标,不产生梯度。
# 梯度从 z_pred_e 回传到 MaskGIT 的 logits
# VQ 参数在此路径上临时冻结requires_grad=False
# 确保编码器权重仅由真实地图路径vq_loss更新不被一致性损失带偏。
for p in model_vq.parameters():
p.requires_grad_(False)
soft_probs = F.softmax(logits / CONSIST_TEMP, dim=-1) # [B, H*W, V]
tile_emb = model_vq.tile_embedding.weight # [V, d_model]
soft_emb = soft_probs @ tile_emb # [B, H*W, d_model]
z_pred_e = model_vq.encode_soft(soft_emb) # [B, L, d_z]
consist_loss = F.mse_loss(z_pred_e, z_e.detach())
for p in model_vq.parameters():
p.requires_grad_(True)
# 5. 联合损失
loss = masked_ce + vq_loss + CONSIST_LAMBDA * consist_loss