diff --git a/ginka/train_vae.py b/ginka/train_vae.py index 1589a8a..df72c47 100644 --- a/ginka/train_vae.py +++ b/ginka/train_vae.py @@ -192,10 +192,8 @@ def train(): index2 = random.randint(0, val_length - 1) map1 = torch.LongTensor(dataset_val.data[index1]["map"]).to(device).reshape(1, 13, 13) map2 = torch.LongTensor(dataset_val.data[index2]["map"]).to(device).reshape(1, 13, 13) - map1_onehot = F.one_hot(map1, 32).permute(0, 3, 1, 2).float().to(device) - map2_onehot = F.one_hot(map2, 32).permute(0, 3, 1, 2).float().to(device) - mu1, logvar1 = vae.encoder(map1_onehot) - mu2, logvar2 = vae.encoder(map2_onehot) + mu1, logvar1 = vae.encoder(map1) + mu2, logvar2 = vae.encoder(map2) z1 = vae.reparameterize(mu1, logvar1) z2 = vae.reparameterize(mu2, logvar2) real_img1 = matrix_to_image_cv(map1[0].cpu().numpy(), tile_dict)