From becf625bdb21b577d5f947eafc4aff7af5428c79 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Thu, 12 Feb 2026 23:50:08 +0800 Subject: [PATCH] =?UTF-8?q?chore:=20=E8=B0=83=E6=95=B4=E9=83=A8=E5=88=86?= =?UTF-8?q?=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/train_vae.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/ginka/train_vae.py b/ginka/train_vae.py index 53bbf9b..9965ac9 100644 --- a/ginka/train_vae.py +++ b/ginka/train_vae.py @@ -58,6 +58,7 @@ LATENT_DIM = 48 KL_BETA = 0.1 SELF_GATE = 0.5 GATE_EPOCH = 5 +VAL_BATCH_DIVIDER = 1 device = torch.device( "cuda:1" if torch.cuda.is_available() @@ -92,12 +93,12 @@ def train(): dataset = GinkaRNNDataset(args.train, device) dataset_val = GinkaRNNDataset(args.validate, device) dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True) - dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // 64, shuffle=True) + dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // VAL_BATCH_DIVIDER, shuffle=True) optimizer_ginka = optim.AdamW(vae.parameters(), lr=2e-4, weight_decay=1e-4) # 自定义调度器允许在 self_prob 提高时重置调度器信息并提高学习率以适应学习 scheduler_ginka = VAEScheduler( - optimizer_ginka, factor=0.9, increase_factor=1.1, patience=10, max_lr=2e-4, min_lr=1e-6 + optimizer_ginka, factor=0.9, increase_factor=2, patience=10, max_lr=2e-4, min_lr=1e-6 ) criterion = VAELoss() @@ -166,11 +167,11 @@ def train(): # 先使用训练集的损失值,因为过拟合比较严重,后续再想办法 if avg_loss < SELF_GATE: - gate_epochs += 1 + prob_epochs += 1 - if gate_epochs >= GATE_EPOCH and self_prob < 1: + if prob_epochs >= GATE_EPOCH and self_prob < 1: self_prob += 0.01 - gate_epochs = 0 + prob_epochs = 0 scheduler_ginka.step(avg_loss, self_prob)