mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
chore: 调整损失值计算
This commit is contained in:
parent
164bb24823
commit
3be014f3ad
@ -107,7 +107,7 @@ def predict_x0(diffusion: Diffusion, x_t: torch.Tensor, pred_noise: torch.Tensor
|
||||
sqrt_ab = diffusion.sqrt_ab[t][:, None, None, None]
|
||||
sqrt_one_minus_ab = diffusion.sqrt_one_minus_ab[t][:, None, None, None]
|
||||
x0 = (x_t - sqrt_one_minus_ab * pred_noise) / sqrt_ab
|
||||
return torch.clamp(x0, 0.0, 1.0)
|
||||
return x0
|
||||
|
||||
|
||||
def maskgit_joint_loss(maskgit, generated_heatmap: torch.Tensor, target_map: torch.Tensor):
|
||||
@ -127,10 +127,8 @@ def maskgit_joint_loss(maskgit, generated_heatmap: torch.Tensor, target_map: tor
|
||||
ce = F.cross_entropy(
|
||||
logits.permute(0, 2, 1),
|
||||
target_tokens,
|
||||
reduction='none',
|
||||
label_smoothing=LABEL_SMOOTHING
|
||||
)
|
||||
ce = (ce * current_mask).sum() / (current_mask.sum() + 1e-6)
|
||||
losses.append(ce)
|
||||
|
||||
with torch.no_grad():
|
||||
|
||||
Loading…
Reference in New Issue
Block a user