diff --git a/ginka/train_vq.py b/ginka/train_vq.py index 6d9eb37..fbf42b9 100644 --- a/ginka/train_vq.py +++ b/ginka/train_vq.py @@ -625,14 +625,21 @@ def train(): # 4. z 一致性约束(方案 A):将 MaskGIT 的 logits 经温度平滑后 # 与 VQ 编码器的 tile embedding 做加权求和,得到软嵌入序列, # 再送入编码器得到 z_pred_e,约束其与真实 z_e 对齐。 - # 梯度从 z_pred_e 回传到 MaskGIT 的 logits(以及 VQ encoder 的权重); - # z_e 作为 detach 后的监督目标,不产生梯度。 + # 梯度从 z_pred_e 回传到 MaskGIT 的 logits; + # VQ 参数在此路径上临时冻结(requires_grad=False), + # 确保编码器权重仅由真实地图路径(vq_loss)更新,不被一致性损失带偏。 + for p in model_vq.parameters(): + p.requires_grad_(False) + soft_probs = F.softmax(logits / CONSIST_TEMP, dim=-1) # [B, H*W, V] tile_emb = model_vq.tile_embedding.weight # [V, d_model] soft_emb = soft_probs @ tile_emb # [B, H*W, d_model] z_pred_e = model_vq.encode_soft(soft_emb) # [B, L, d_z] consist_loss = F.mse_loss(z_pred_e, z_e.detach()) + for p in model_vq.parameters(): + p.requires_grad_(True) + # 5. 联合损失 loss = masked_ce + vq_loss + CONSIST_LAMBDA * consist_loss