From 9887abcd014226ea2897f2a02cbfb3d47802778b Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Fri, 6 Feb 2026 01:23:58 +0800 Subject: [PATCH] fix: zero_grad --- ginka/train_vae.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ginka/train_vae.py b/ginka/train_vae.py index d72fa36..ed3126e 100644 --- a/ginka/train_vae.py +++ b/ginka/train_vae.py @@ -87,7 +87,7 @@ def train(): dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // 64, shuffle=True) optimizer_ginka = optim.AdamW(vae.parameters(), lr=2e-4, weight_decay=1e-4) - scheduler_ginka = optim.lr_scheduler.ReduceLROnPlateau(optimizer_ginka, factor=0.9, patience=40) + scheduler_ginka = optim.lr_scheduler.ReduceLROnPlateau(optimizer_ginka, factor=0.9, patience=40, min_lr=1e-6) criterion = VAELoss() @@ -116,6 +116,7 @@ def train(): for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm): target_map = batch["target_map"].to(device) + optimizer_ginka.zero_grad() fake_logits, z = vae(target_map, 1 - gt_prob) loss = criterion.vae_loss(fake_logits, target_map) @@ -204,7 +205,7 @@ def train(): ) if avg_loss_val < 0.5 and gt_prob > 0: - gt_prob -= 0.1 + gt_prob -= 0.01 print("Train ended.") torch.save({