From 8caa37a14436722c1c9fc6cef8ef982d1ee00799 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Tue, 20 Jan 2026 19:05:00 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E9=AA=8C=E8=AF=81=E6=8A=A5=E9=94=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/train_vae.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/ginka/train_vae.py b/ginka/train_vae.py index 1589a8a..df72c47 100644 --- a/ginka/train_vae.py +++ b/ginka/train_vae.py @@ -192,10 +192,8 @@ def train(): index2 = random.randint(0, val_length - 1) map1 = torch.LongTensor(dataset_val.data[index1]["map"]).to(device).reshape(1, 13, 13) map2 = torch.LongTensor(dataset_val.data[index2]["map"]).to(device).reshape(1, 13, 13) - map1_onehot = F.one_hot(map1, 32).permute(0, 3, 1, 2).float().to(device) - map2_onehot = F.one_hot(map2, 32).permute(0, 3, 1, 2).float().to(device) - mu1, logvar1 = vae.encoder(map1_onehot) - mu2, logvar2 = vae.encoder(map2_onehot) + mu1, logvar1 = vae.encoder(map1) + mu2, logvar2 = vae.encoder(map2) z1 = vae.reparameterize(mu1, logvar1) z2 = vae.reparameterize(mu2, logvar2) real_img1 = matrix_to_image_cv(map1[0].cpu().numpy(), tile_dict)