From 3be014f3adf8847289827b10fa9eae56dcc5d4ee Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Thu, 23 Apr 2026 19:24:17 +0800 Subject: [PATCH] =?UTF-8?q?chore:=20=E8=B0=83=E6=95=B4=E6=8D=9F=E5=A4=B1?= =?UTF-8?q?=E5=80=BC=E8=AE=A1=E7=AE=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/train_joint.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/ginka/train_joint.py b/ginka/train_joint.py index e709fd7..a1c9db4 100644 --- a/ginka/train_joint.py +++ b/ginka/train_joint.py @@ -107,7 +107,7 @@ def predict_x0(diffusion: Diffusion, x_t: torch.Tensor, pred_noise: torch.Tensor sqrt_ab = diffusion.sqrt_ab[t][:, None, None, None] sqrt_one_minus_ab = diffusion.sqrt_one_minus_ab[t][:, None, None, None] x0 = (x_t - sqrt_one_minus_ab * pred_noise) / sqrt_ab - return torch.clamp(x0, 0.0, 1.0) + return x0 def maskgit_joint_loss(maskgit, generated_heatmap: torch.Tensor, target_map: torch.Tensor): @@ -127,10 +127,8 @@ def maskgit_joint_loss(maskgit, generated_heatmap: torch.Tensor, target_map: tor ce = F.cross_entropy( logits.permute(0, 2, 1), target_tokens, - reduction='none', label_smoothing=LABEL_SMOOTHING ) - ce = (ce * current_mask).sum() / (current_mask.sum() + 1e-6) losses.append(ce) with torch.no_grad():