fix: zero_grad

This commit is contained in:
unanmed 2026-02-06 01:23:58 +08:00
parent b01357d99e
commit 9887abcd01

View File

@ -87,7 +87,7 @@ def train():
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // 64, shuffle=True)
optimizer_ginka = optim.AdamW(vae.parameters(), lr=2e-4, weight_decay=1e-4)
scheduler_ginka = optim.lr_scheduler.ReduceLROnPlateau(optimizer_ginka, factor=0.9, patience=40)
scheduler_ginka = optim.lr_scheduler.ReduceLROnPlateau(optimizer_ginka, factor=0.9, patience=40, min_lr=1e-6)
criterion = VAELoss()
@ -116,6 +116,7 @@ def train():
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
target_map = batch["target_map"].to(device)
optimizer_ginka.zero_grad()
fake_logits, z = vae(target_map, 1 - gt_prob)
loss = criterion.vae_loss(fake_logits, target_map)
@ -204,7 +205,7 @@ def train():
)
if avg_loss_val < 0.5 and gt_prob > 0:
gt_prob -= 0.1
gt_prob -= 0.01
print("Train ended.")
torch.save({