mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
20 lines · 628 B · Python
import torch
|
|
import torch.nn.functional as F
|
|
|
|
class VAELoss:
    """Loss for a VAE that decodes a token sequence terminated by an end token.

    The total loss is ``recon_loss + beta * kl_loss``, where ``recon_loss`` is
    token-level cross-entropy and ``kl_loss`` is the closed-form KL divergence
    of the diagonal Gaussian ``N(mu, exp(logvar))`` from ``N(0, I)``.
    """

    def __init__(self, num_classes=32, end_token=15):
        """
        Args:
            num_classes: Size of the token vocabulary; must equal the last
                dimension of ``logits`` passed to ``vae_loss``.
            end_token: Token id appended to every target sequence before
                scoring. Should lie in ``[0, num_classes)``.
        """
        self.num_classes = num_classes
        self.end_token = end_token

    def vae_loss(self, logits, target, mu, logvar, beta=0.1):
        """Compute the VAE loss.

        Args:
            logits: Unnormalized class scores with classes in the last
                dimension, e.g. ``[B, 170, num_classes]`` for a ``[B, 169]``
                target (one extra position for the appended end token).
            target: Integer token ids, e.g. ``[B, 169]``. The end token is
                appended along the last dimension before scoring.
            mu: Posterior means.
            logvar: Posterior log-variances (same shape as ``mu``).
            beta: Weight of the KL term.

        Returns:
            ``(total_loss, recon_loss, kl_loss)`` as scalar tensors.

        Raises:
            ValueError: If the last dimension of ``logits`` is not
                ``num_classes``.
        """
        if logits.size(-1) != self.num_classes:
            raise ValueError(
                f"logits last dim must be {self.num_classes}, "
                f"got {logits.size(-1)}"
            )

        # Append the end token to every sequence. torch.cat needs operands of
        # equal rank, so build a [..., 1] column rather than a bare [1]
        # tensor; new_full keeps target's dtype and device.
        target = target.long()
        end = target.new_full((*target.shape[:-1], 1), self.end_token)
        target = torch.cat([target, end], dim=-1)

        # F.cross_entropy treats dim 1 as the class dim. Flatten every
        # position into the batch dim and score against class indices so
        # the softmax runs over the vocabulary, not over sequence positions.
        recon_loss = F.cross_entropy(
            logits.reshape(-1, self.num_classes), target.reshape(-1)
        )

        # KL(N(mu, exp(logvar)) || N(0, 1)) averaged over every element
        # (batch and latent dims alike).
        kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())

        return recon_loss + beta * kl_loss, recon_loss, kl_loss