mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-21 02:11:13 +08:00
fix: zero_grad
This commit is contained in:
parent
b01357d99e
commit
9887abcd01
@ -87,7 +87,7 @@ def train():
|
|||||||
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // 64, shuffle=True)
|
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // 64, shuffle=True)
|
||||||
|
|
||||||
optimizer_ginka = optim.AdamW(vae.parameters(), lr=2e-4, weight_decay=1e-4)
|
optimizer_ginka = optim.AdamW(vae.parameters(), lr=2e-4, weight_decay=1e-4)
|
||||||
scheduler_ginka = optim.lr_scheduler.ReduceLROnPlateau(optimizer_ginka, factor=0.9, patience=40)
|
scheduler_ginka = optim.lr_scheduler.ReduceLROnPlateau(optimizer_ginka, factor=0.9, patience=40, min_lr=1e-6)
|
||||||
|
|
||||||
criterion = VAELoss()
|
criterion = VAELoss()
|
||||||
|
|
||||||
@ -116,6 +116,7 @@ def train():
|
|||||||
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
|
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
|
||||||
target_map = batch["target_map"].to(device)
|
target_map = batch["target_map"].to(device)
|
||||||
|
|
||||||
|
optimizer_ginka.zero_grad()
|
||||||
fake_logits, z = vae(target_map, 1 - gt_prob)
|
fake_logits, z = vae(target_map, 1 - gt_prob)
|
||||||
|
|
||||||
loss = criterion.vae_loss(fake_logits, target_map)
|
loss = criterion.vae_loss(fake_logits, target_map)
|
||||||
@ -204,7 +205,7 @@ def train():
|
|||||||
)
|
)
|
||||||
|
|
||||||
if avg_loss_val < 0.5 and gt_prob > 0:
|
if avg_loss_val < 0.5 and gt_prob > 0:
|
||||||
gt_prob -= 0.1
|
gt_prob -= 0.01
|
||||||
|
|
||||||
print("Train ended.")
|
print("Train ended.")
|
||||||
torch.save({
|
torch.save({
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user