fix: 验证报错

This commit is contained in:
unanmed 2026-01-20 19:05:00 +08:00
parent 4d244d021a
commit 8caa37a144

View File

@ -192,10 +192,8 @@ def train():
index2 = random.randint(0, val_length - 1)
map1 = torch.LongTensor(dataset_val.data[index1]["map"]).to(device).reshape(1, 13, 13)
map2 = torch.LongTensor(dataset_val.data[index2]["map"]).to(device).reshape(1, 13, 13)
map1_onehot = F.one_hot(map1, 32).permute(0, 3, 1, 2).float().to(device)
map2_onehot = F.one_hot(map2, 32).permute(0, 3, 1, 2).float().to(device)
mu1, logvar1 = vae.encoder(map1_onehot)
mu2, logvar2 = vae.encoder(map2_onehot)
mu1, logvar1 = vae.encoder(map1)
mu2, logvar2 = vae.encoder(map2)
z1 = vae.reparameterize(mu1, logvar1)
z2 = vae.reparameterize(mu2, logvar2)
real_img1 = matrix_to_image_cv(map1[0].cpu().numpy(), tile_dict)