mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
fix: zero_grad
This commit is contained in:
parent
b01357d99e
commit
9887abcd01
@ -87,7 +87,7 @@ def train():
|
||||
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // 64, shuffle=True)
|
||||
|
||||
optimizer_ginka = optim.AdamW(vae.parameters(), lr=2e-4, weight_decay=1e-4)
|
||||
scheduler_ginka = optim.lr_scheduler.ReduceLROnPlateau(optimizer_ginka, factor=0.9, patience=40)
|
||||
scheduler_ginka = optim.lr_scheduler.ReduceLROnPlateau(optimizer_ginka, factor=0.9, patience=40, min_lr=1e-6)
|
||||
|
||||
criterion = VAELoss()
|
||||
|
||||
@ -116,6 +116,7 @@ def train():
|
||||
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
|
||||
target_map = batch["target_map"].to(device)
|
||||
|
||||
optimizer_ginka.zero_grad()
|
||||
fake_logits, z = vae(target_map, 1 - gt_prob)
|
||||
|
||||
loss = criterion.vae_loss(fake_logits, target_map)
|
||||
@ -204,7 +205,7 @@ def train():
|
||||
)
|
||||
|
||||
if avg_loss_val < 0.5 and gt_prob > 0:
|
||||
gt_prob -= 0.1
|
||||
gt_prob -= 0.01
|
||||
|
||||
print("Train ended.")
|
||||
torch.save({
|
||||
|
||||
Loading…
Reference in New Issue
Block a user