diff --git a/ginka/train_joint.py b/ginka/train_joint.py index e709fd7..a1c9db4 100644 --- a/ginka/train_joint.py +++ b/ginka/train_joint.py @@ -107,7 +107,7 @@ def predict_x0(diffusion: Diffusion, x_t: torch.Tensor, pred_noise: torch.Tensor sqrt_ab = diffusion.sqrt_ab[t][:, None, None, None] sqrt_one_minus_ab = diffusion.sqrt_one_minus_ab[t][:, None, None, None] x0 = (x_t - sqrt_one_minus_ab * pred_noise) / sqrt_ab - return torch.clamp(x0, 0.0, 1.0) + return x0 def maskgit_joint_loss(maskgit, generated_heatmap: torch.Tensor, target_map: torch.Tensor): @@ -127,10 +127,8 @@ def maskgit_joint_loss(maskgit, generated_heatmap: torch.Tensor, target_map: tor ce = F.cross_entropy( logits.permute(0, 2, 1), target_tokens, - reduction='none', label_smoothing=LABEL_SMOOTHING ) - ce = (ce * current_mask).sum() / (current_mask.sum() + 1e-6) losses.append(ce) with torch.no_grad():