From 14f391f4f43b94fe8537dfad9dd22329622197f6 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Tue, 20 Jan 2026 16:43:59 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20vae=E5=8F=82=E6=95=B0=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/vae_rnn/vae.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/ginka/vae_rnn/vae.py b/ginka/vae_rnn/vae.py index 95da9ad..e183f6c 100644 --- a/ginka/vae_rnn/vae.py +++ b/ginka/vae_rnn/vae.py @@ -7,7 +7,7 @@ from .decoder import VAEDecoder class GinkaVAE(nn.Module): def __init__(self, device, tile_classes=32, latent_dim=32): super().__init__() - self.encoder = VAEEncoder(tile_classes, latent_dim) + self.encoder = VAEEncoder(device, tile_classes, latent_dim) self.decoder = VAEDecoder(device) def reparameterize(self, mu, logvar): @@ -16,8 +16,7 @@ class GinkaVAE(nn.Module): return mu + eps * std def forward(self, target_map: torch.Tensor, use_self_probility=0): - target = F.one_hot(target_map, num_classes=32).float().permute(0, 3, 1, 2) - mu, logvar = self.encoder(target) + mu, logvar = self.encoder(target_map) z = self.reparameterize(mu, logvar) logits = self.decoder(z, target_map, use_self_probility) return logits, mu, logvar \ No newline at end of file