From f22f06ef2d653ca1569daa3b42e4c56153c62a03 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Mon, 19 Jan 2026 23:54:11 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E6=8F=92=E5=80=BC=E6=96=87=E4=BB=B6?= =?UTF-8?q?=E5=90=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/train_vae.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ginka/train_vae.py b/ginka/train_vae.py index dfb07f4..74844a7 100644 --- a/ginka/train_vae.py +++ b/ginka/train_vae.py @@ -196,6 +196,7 @@ def train(): z2 = vae.reparameterize(mu2, logvar2) real_img1 = matrix_to_image_cv(map1[0], tile_dict) real_img2 = matrix_to_image_cv(map2[0], tile_dict) + i = 0 for t in torch.linspace(0, 1, 8): z = z1 * (1 - t / 8) + z2 * t / 8 fake_logits = vae.decoder(z, torch.zeros(1, 13, 13).to(device), 1) @@ -203,7 +204,8 @@ def train(): fake_img = matrix_to_image_cv(fake_map[0], tile_dict) img = np.block([[real_img1], [vline], [fake_img], [vline], [real_img2]]) - cv2.imwrite(f"result/ginka_vae_img/{t}_linspace.png", img) + cv2.imwrite(f"result/ginka_vae_img/{i}_linspace.png", img) + i += 1 avg_loss_val = val_loss_total.item() / len(dataloader_val) avg_reco_loss = reco_loss_total.item() / len(dataloader_val)