mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
fix: 验证报错
This commit is contained in:
parent
4d244d021a
commit
8caa37a144
@ -192,10 +192,8 @@ def train():
|
||||
index2 = random.randint(0, val_length - 1)
|
||||
map1 = torch.LongTensor(dataset_val.data[index1]["map"]).to(device).reshape(1, 13, 13)
|
||||
map2 = torch.LongTensor(dataset_val.data[index2]["map"]).to(device).reshape(1, 13, 13)
|
||||
map1_onehot = F.one_hot(map1, 32).permute(0, 3, 1, 2).float().to(device)
|
||||
map2_onehot = F.one_hot(map2, 32).permute(0, 3, 1, 2).float().to(device)
|
||||
mu1, logvar1 = vae.encoder(map1_onehot)
|
||||
mu2, logvar2 = vae.encoder(map2_onehot)
|
||||
mu1, logvar1 = vae.encoder(map1)
|
||||
mu2, logvar2 = vae.encoder(map2)
|
||||
z1 = vae.reparameterize(mu1, logvar1)
|
||||
z2 = vae.reparameterize(mu2, logvar2)
|
||||
real_img1 = matrix_to_image_cv(map1[0].cpu().numpy(), tile_dict)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user