From 789107969b5ce345a8516f79db1caa0e18c1f4f2 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Tue, 28 Apr 2026 15:53:41 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20MaskGIT=20=E8=BE=93=E5=87=BA=E5=9C=B0?= =?UTF-8?q?=E5=9B=BE=E4=B8=8D=E6=9B=B4=E6=96=B0=E7=BC=96=E7=A0=81=E5=99=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Copilot --- ginka/train_vq.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/ginka/train_vq.py b/ginka/train_vq.py index 6d9eb37..fbf42b9 100644 --- a/ginka/train_vq.py +++ b/ginka/train_vq.py @@ -625,14 +625,21 @@ def train(): # 4. z 一致性约束(方案 A):将 MaskGIT 的 logits 经温度平滑后 # 与 VQ 编码器的 tile embedding 做加权求和,得到软嵌入序列, # 再送入编码器得到 z_pred_e,约束其与真实 z_e 对齐。 - # 梯度从 z_pred_e 回传到 MaskGIT 的 logits(以及 VQ encoder 的权重); - # z_e 作为 detach 后的监督目标,不产生梯度。 + # 梯度从 z_pred_e 回传到 MaskGIT 的 logits; + # VQ 参数在此路径上临时冻结(requires_grad=False), + # 确保编码器权重仅由真实地图路径(vq_loss)更新,不被一致性损失带偏。 + for p in model_vq.parameters(): + p.requires_grad_(False) + soft_probs = F.softmax(logits / CONSIST_TEMP, dim=-1) # [B, H*W, V] tile_emb = model_vq.tile_embedding.weight # [V, d_model] soft_emb = soft_probs @ tile_emb # [B, H*W, d_model] z_pred_e = model_vq.encode_soft(soft_emb) # [B, L, d_z] consist_loss = F.mse_loss(z_pred_e, z_e.detach()) + for p in model_vq.parameters(): + p.requires_grad_(True) + # 5. 联合损失 loss = masked_ce + vq_loss + CONSIST_LAMBDA * consist_loss