fix: 插值文件名

This commit is contained in:
unanmed 2026-01-19 23:54:11 +08:00
parent 83d6c31704
commit f22f06ef2d

View File

@ -196,6 +196,7 @@ def train():
z2 = vae.reparameterize(mu2, logvar2)
real_img1 = matrix_to_image_cv(map1[0], tile_dict)
real_img2 = matrix_to_image_cv(map2[0], tile_dict)
i = 0
for t in torch.linspace(0, 1, 8):
z = z1 * (1 - t / 8) + z2 * t / 8
fake_logits = vae.decoder(z, torch.zeros(1, 13, 13).to(device), 1)
@ -203,7 +204,8 @@ def train():
fake_img = matrix_to_image_cv(fake_map[0], tile_dict)
img = np.block([[real_img1], [vline], [fake_img], [vline], [real_img2]])
cv2.imwrite(f"result/ginka_vae_img/{t}_linspace.png", img)
cv2.imwrite(f"result/ginka_vae_img/{i}_linspace.png", img)
i += 1
avg_loss_val = val_loss_total.item() / len(dataloader_val)
avg_reco_loss = reco_loss_total.item() / len(dataloader_val)