diff --git a/ginka/train_vae.py b/ginka/train_vae.py index b529c71..7ec75e4 100644 --- a/ginka/train_vae.py +++ b/ginka/train_vae.py @@ -54,7 +54,7 @@ from shared.image import matrix_to_image_cv BATCH_SIZE = 128 LATENT_DIM = 48 -KL_BETA = 0.05 +KL_BETA = 0.1 device = torch.device( "cuda:1" if torch.cuda.is_available() @@ -145,21 +145,22 @@ def train(): ) # 验证集 - with torch.no_grad(): - val_loss_total = torch.Tensor([0]).to(device) - for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm): - target_map = batch["target_map"].to(device) + # with torch.no_grad(): + # val_loss_total = torch.Tensor([0]).to(device) + # for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm): + # target_map = batch["target_map"].to(device) - fake_logits, mu, logvar = vae(target_map, 1 - gt_prob) + # fake_logits, mu, logvar = vae(target_map, 1 - gt_prob) - loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, KL_BETA) - val_loss_total += loss.detach() + # loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, KL_BETA) + # val_loss_total += loss.detach() - avg_loss_val = val_loss_total.item() / len(dataloader_val) - if avg_loss_val < 0.5 and gt_prob > 0: + # avg_loss_val = val_loss_total.item() / len(dataloader_val) + # 先使用训练集的损失值,因为过拟合比较严重,后续再想办法 + if avg_loss < 0.5 and gt_prob > 0: gt_prob -= 0.01 - scheduler_ginka.step(avg_loss_val) + scheduler_ginka.step(avg_loss) # 每若干轮输出一次图片,并保存检查点 if (epoch + 1) % args.checkpoint == 0: