mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
fix: MaskGIT 输出地图不更新编码器
Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
parent
6b2b36f20c
commit
789107969b
@ -625,14 +625,21 @@ def train():
|
||||
# 4. z 一致性约束(方案 A):将 MaskGIT 的 logits 经温度平滑后
|
||||
# 与 VQ 编码器的 tile embedding 做加权求和,得到软嵌入序列,
|
||||
# 再送入编码器得到 z_pred_e,约束其与真实 z_e 对齐。
|
||||
# 梯度从 z_pred_e 回传到 MaskGIT 的 logits(以及 VQ encoder 的权重);
|
||||
# z_e 作为 detach 后的监督目标,不产生梯度。
|
||||
# 梯度从 z_pred_e 回传到 MaskGIT 的 logits;
|
||||
# VQ 参数在此路径上临时冻结(requires_grad=False),
|
||||
# 确保编码器权重仅由真实地图路径(vq_loss)更新,不被一致性损失带偏。
|
||||
for p in model_vq.parameters():
|
||||
p.requires_grad_(False)
|
||||
|
||||
soft_probs = F.softmax(logits / CONSIST_TEMP, dim=-1) # [B, H*W, V]
|
||||
tile_emb = model_vq.tile_embedding.weight # [V, d_model]
|
||||
soft_emb = soft_probs @ tile_emb # [B, H*W, d_model]
|
||||
z_pred_e = model_vq.encode_soft(soft_emb) # [B, L, d_z]
|
||||
consist_loss = F.mse_loss(z_pred_e, z_e.detach())
|
||||
|
||||
for p in model_vq.parameters():
|
||||
p.requires_grad_(True)
|
||||
|
||||
# 5. 联合损失
|
||||
loss = masked_ce + vq_loss + CONSIST_LAMBDA * consist_loss
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user