mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
chore: 调整损失值计算
This commit is contained in:
parent
164bb24823
commit
3be014f3ad
@ -107,7 +107,7 @@ def predict_x0(diffusion: Diffusion, x_t: torch.Tensor, pred_noise: torch.Tensor
|
|||||||
sqrt_ab = diffusion.sqrt_ab[t][:, None, None, None]
|
sqrt_ab = diffusion.sqrt_ab[t][:, None, None, None]
|
||||||
sqrt_one_minus_ab = diffusion.sqrt_one_minus_ab[t][:, None, None, None]
|
sqrt_one_minus_ab = diffusion.sqrt_one_minus_ab[t][:, None, None, None]
|
||||||
x0 = (x_t - sqrt_one_minus_ab * pred_noise) / sqrt_ab
|
x0 = (x_t - sqrt_one_minus_ab * pred_noise) / sqrt_ab
|
||||||
return torch.clamp(x0, 0.0, 1.0)
|
return x0
|
||||||
|
|
||||||
|
|
||||||
def maskgit_joint_loss(maskgit, generated_heatmap: torch.Tensor, target_map: torch.Tensor):
|
def maskgit_joint_loss(maskgit, generated_heatmap: torch.Tensor, target_map: torch.Tensor):
|
||||||
@ -127,10 +127,8 @@ def maskgit_joint_loss(maskgit, generated_heatmap: torch.Tensor, target_map: tor
|
|||||||
ce = F.cross_entropy(
|
ce = F.cross_entropy(
|
||||||
logits.permute(0, 2, 1),
|
logits.permute(0, 2, 1),
|
||||||
target_tokens,
|
target_tokens,
|
||||||
reduction='none',
|
|
||||||
label_smoothing=LABEL_SMOOTHING
|
label_smoothing=LABEL_SMOOTHING
|
||||||
)
|
)
|
||||||
ce = (ce * current_mask).sum() / (current_mask.sum() + 1e-6)
|
|
||||||
losses.append(ce)
|
losses.append(ce)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user