From dc6d1c69beb82a6d0f9086d920bbd3295b6236ef Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Fri, 6 Feb 2026 14:35:31 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E9=AA=8C=E8=AF=81=E6=8A=A5=E9=94=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/train_vae.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/ginka/train_vae.py b/ginka/train_vae.py index 9956e55..a13adcb 100644 --- a/ginka/train_vae.py +++ b/ginka/train_vae.py @@ -142,15 +142,14 @@ def train(): # 验证集 with torch.no_grad(): + val_loss_total = torch.Tensor([0]).to(device) for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm): target_map = batch["target_map"].to(device) fake_logits, mu, logvar = vae(target_map, 1 - gt_prob) - loss = criterion.vae_loss(fake_logits, target_map) + loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, KL_BETA) val_loss_total += loss.detach() - val_reco_loss_total += loss.detach() - val_kl_loss_total += loss.detach() idx += 1