mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-17 23:21:20 +08:00
14 lines
363 B
Python
14 lines
363 B
Python
import torch
|
|
import torch.nn.functional as F
|
|
|
|
class VAELoss:
    """Per-cell categorical reconstruction loss.

    Only a cross-entropy reconstruction term is computed by this class; no
    KL-divergence term appears here.
    NOTE(review): presumably the KL term is added by the caller — confirm.
    """

    def __init__(self, num_classes=32):
        """Create the loss.

        Args:
            num_classes: number of classes per cell; must match the channel
                dimension (dim 1) of the logits passed to ``vae_loss``.
                Defaults to 32.
        """
        self.num_classes = num_classes

    def vae_loss(self, logits, target):
        """Return the mean cross-entropy between ``logits`` and ``target``.

        Args:
            logits: unnormalized class scores of shape
                [B, num_classes, H, W].
            target: int64 class indices of shape [B, H, W]
                (e.g. [B, 13, 13]), with values in [0, num_classes).

        Returns:
            Scalar tensor: cross-entropy averaged over all B*H*W cells.

        Raises:
            ValueError: if the channel dimension of ``logits`` differs from
                ``num_classes``.
        """
        if logits.shape[1] != self.num_classes:
            # The former one-hot target failed implicitly on this mismatch;
            # keep the consistency check explicit now that it is gone.
            raise ValueError(
                f"expected logits with {self.num_classes} channels, "
                f"got {logits.shape[1]}"
            )
        # Passing class indices is equivalent to a one-hot probability target
        # but avoids materializing a [B, num_classes, H, W] float tensor.
        recon_loss = F.cross_entropy(logits, target)
        return recon_loss
|