chore: 调整损失值计算

This commit is contained in:
unanmed 2026-04-23 19:24:17 +08:00
parent 164bb24823
commit 3be014f3ad

View File

@ -107,7 +107,7 @@ def predict_x0(diffusion: Diffusion, x_t: torch.Tensor, pred_noise: torch.Tensor
sqrt_ab = diffusion.sqrt_ab[t][:, None, None, None] sqrt_ab = diffusion.sqrt_ab[t][:, None, None, None]
sqrt_one_minus_ab = diffusion.sqrt_one_minus_ab[t][:, None, None, None] sqrt_one_minus_ab = diffusion.sqrt_one_minus_ab[t][:, None, None, None]
x0 = (x_t - sqrt_one_minus_ab * pred_noise) / sqrt_ab x0 = (x_t - sqrt_one_minus_ab * pred_noise) / sqrt_ab
return torch.clamp(x0, 0.0, 1.0) return x0
def maskgit_joint_loss(maskgit, generated_heatmap: torch.Tensor, target_map: torch.Tensor): def maskgit_joint_loss(maskgit, generated_heatmap: torch.Tensor, target_map: torch.Tensor):
@ -127,10 +127,8 @@ def maskgit_joint_loss(maskgit, generated_heatmap: torch.Tensor, target_map: tor
ce = F.cross_entropy( ce = F.cross_entropy(
logits.permute(0, 2, 1), logits.permute(0, 2, 1),
target_tokens, target_tokens,
reduction='none',
label_smoothing=LABEL_SMOOTHING label_smoothing=LABEL_SMOOTHING
) )
ce = (ce * current_mask).sum() / (current_mask.sum() + 1e-6)
losses.append(ce) losses.append(ce)
with torch.no_grad(): with torch.no_grad():