From c8809b8ee74f72d44cb5c5f8758ce202ff37d416 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Fri, 13 Feb 2026 00:05:38 +0800 Subject: [PATCH] =?UTF-8?q?chore:=20=E8=B0=83=E6=95=B4=E8=B0=83=E5=BA=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/train_vae.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/ginka/train_vae.py b/ginka/train_vae.py index 9965ac9..238321f 100644 --- a/ginka/train_vae.py +++ b/ginka/train_vae.py @@ -58,7 +58,8 @@ LATENT_DIM = 48 KL_BETA = 0.1 SELF_GATE = 0.5 GATE_EPOCH = 5 -VAL_BATCH_DIVIDER = 1 +VAL_BATCH_DIVIDER = 128 +PROB_STEP = 0.05 device = torch.device( "cuda:1" if torch.cuda.is_available() @@ -168,10 +169,14 @@ def train(): # 先使用训练集的损失值,因为过拟合比较严重,后续再想办法 if avg_loss < SELF_GATE: prob_epochs += 1 + else: + prob_epochs = 0 if prob_epochs >= GATE_EPOCH and self_prob < 1: - self_prob += 0.01 + self_prob += PROB_STEP prob_epochs = 0 + if self_prob > 1: + self_prob = 1 scheduler_ginka.step(avg_loss, self_prob)