fix: 验证报错

This commit is contained in:
unanmed 2026-02-06 14:35:31 +08:00
parent 948132797d
commit dc6d1c69be

View File

@ -142,15 +142,14 @@ def train():
# 验证集
with torch.no_grad():
val_loss_total = torch.Tensor([0]).to(device)
for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
target_map = batch["target_map"].to(device)
fake_logits, mu, logvar = vae(target_map, 1 - gt_prob)
loss = criterion.vae_loss(fake_logits, target_map)
loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, KL_BETA)
val_loss_total += loss.detach()
val_reco_loss_total += loss.detach()
val_kl_loss_total += loss.detach()
idx += 1