From a07d2cf9602c741b85fc1db278b38322d901fe3b Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Thu, 12 Feb 2026 23:41:45 +0800 Subject: [PATCH] =?UTF-8?q?chore:=20=E8=B0=83=E6=95=B4=E8=B0=83=E5=BA=A6?= =?UTF-8?q?=E6=96=B9=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/train_vae.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/ginka/train_vae.py b/ginka/train_vae.py index e440ac8..53bbf9b 100644 --- a/ginka/train_vae.py +++ b/ginka/train_vae.py @@ -56,6 +56,8 @@ from shared.image import matrix_to_image_cv BATCH_SIZE = 128 LATENT_DIM = 48 KL_BETA = 0.1 +SELF_GATE = 0.5 +GATE_EPOCH = 5 device = torch.device( "cuda:1" if torch.cuda.is_available() @@ -101,6 +103,7 @@ def train(): criterion = VAELoss() self_prob = 0 + gate_epochs = 0 # 用于生成图片 tile_dict = dict() @@ -153,16 +156,21 @@ def train(): # val_loss_total = torch.Tensor([0]).to(device) # for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm): # target_map = batch["target_map"].to(device) - + # fake_logits, mu, logvar = vae(target_map, 1 - gt_prob) # loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, KL_BETA) # val_loss_total += loss.detach() # avg_loss_val = val_loss_total.item() / len(dataloader_val) + # 先使用训练集的损失值,因为过拟合比较严重,后续再想办法 - if avg_loss < 0.5 and self_prob < 1: + if avg_loss < SELF_GATE: + gate_epochs += 1 + + if gate_epochs >= GATE_EPOCH and self_prob < 1: self_prob += 0.01 + gate_epochs = 0 scheduler_ginka.step(avg_loss, self_prob)