mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-23 03:51:11 +08:00
fix: vae参数错误
This commit is contained in:
parent
dd6a043487
commit
14f391f4f4
@ -7,7 +7,7 @@ from .decoder import VAEDecoder
|
|||||||
class GinkaVAE(nn.Module):
|
class GinkaVAE(nn.Module):
|
||||||
def __init__(self, device, tile_classes=32, latent_dim=32):
|
def __init__(self, device, tile_classes=32, latent_dim=32):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.encoder = VAEEncoder(tile_classes, latent_dim)
|
self.encoder = VAEEncoder(device, tile_classes, latent_dim)
|
||||||
self.decoder = VAEDecoder(device)
|
self.decoder = VAEDecoder(device)
|
||||||
|
|
||||||
def reparameterize(self, mu, logvar):
|
def reparameterize(self, mu, logvar):
|
||||||
@ -16,8 +16,7 @@ class GinkaVAE(nn.Module):
|
|||||||
return mu + eps * std
|
return mu + eps * std
|
||||||
|
|
||||||
def forward(self, target_map: torch.Tensor, use_self_probility=0):
|
def forward(self, target_map: torch.Tensor, use_self_probility=0):
|
||||||
target = F.one_hot(target_map, num_classes=32).float().permute(0, 3, 1, 2)
|
mu, logvar = self.encoder(target_map)
|
||||||
mu, logvar = self.encoder(target)
|
|
||||||
z = self.reparameterize(mu, logvar)
|
z = self.reparameterize(mu, logvar)
|
||||||
logits = self.decoder(z, target_map, use_self_probility)
|
logits = self.decoder(z, target_map, use_self_probility)
|
||||||
return logits, mu, logvar
|
return logits, mu, logvar
|
||||||
Loading…
Reference in New Issue
Block a user