diff --git a/ginka/vae_rnn/vae.py b/ginka/vae_rnn/vae.py index 95da9ad..e183f6c 100644 --- a/ginka/vae_rnn/vae.py +++ b/ginka/vae_rnn/vae.py @@ -7,7 +7,7 @@ from .decoder import VAEDecoder class GinkaVAE(nn.Module): def __init__(self, device, tile_classes=32, latent_dim=32): super().__init__() - self.encoder = VAEEncoder(tile_classes, latent_dim) + self.encoder = VAEEncoder(device, tile_classes, latent_dim) self.decoder = VAEDecoder(device) def reparameterize(self, mu, logvar): @@ -16,8 +16,7 @@ class GinkaVAE(nn.Module): return mu + eps * std def forward(self, target_map: torch.Tensor, use_self_probility=0): - target = F.one_hot(target_map, num_classes=32).float().permute(0, 3, 1, 2) - mu, logvar = self.encoder(target) + mu, logvar = self.encoder(target_map) z = self.reparameterize(mu, logvar) logits = self.decoder(z, target_map, use_self_probility) return logits, mu, logvar \ No newline at end of file