mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-20 17:22:41 +08:00
fix: 验证报错
This commit is contained in:
parent
948132797d
commit
dc6d1c69be
@ -142,15 +142,14 @@ def train():
|
|||||||
|
|
||||||
# 验证集
|
# 验证集
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
val_loss_total = torch.Tensor([0]).to(device)
|
||||||
for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
|
for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
|
||||||
target_map = batch["target_map"].to(device)
|
target_map = batch["target_map"].to(device)
|
||||||
|
|
||||||
fake_logits, mu, logvar = vae(target_map, 1 - gt_prob)
|
fake_logits, mu, logvar = vae(target_map, 1 - gt_prob)
|
||||||
|
|
||||||
loss = criterion.vae_loss(fake_logits, target_map)
|
loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, KL_BETA)
|
||||||
val_loss_total += loss.detach()
|
val_loss_total += loss.detach()
|
||||||
val_reco_loss_total += loss.detach()
|
|
||||||
val_kl_loss_total += loss.detach()
|
|
||||||
|
|
||||||
idx += 1
|
idx += 1
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user