mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-22 11:01:12 +08:00
chore: 调整超参数
This commit is contained in:
parent
05b3b7c171
commit
f22afb5f72
@ -63,7 +63,7 @@ disable_tqdm = not sys.stdout.isatty()
|
|||||||
|
|
||||||
def gt_prob(epoch: int, max_epoch: int) -> float:
|
def gt_prob(epoch: int, max_epoch: int) -> float:
|
||||||
progress = epoch / max_epoch
|
progress = epoch / max_epoch
|
||||||
return max(1.2 * progress - 0.2, 0)
|
return max(2 * progress - 1, 0)
|
||||||
|
|
||||||
def parse_arguments():
|
def parse_arguments():
|
||||||
parser = argparse.ArgumentParser(description="training codes")
|
parser = argparse.ArgumentParser(description="training codes")
|
||||||
@ -89,7 +89,7 @@ def train():
|
|||||||
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||||
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // 4, shuffle=True)
|
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // 4, shuffle=True)
|
||||||
|
|
||||||
optimizer_ginka = optim.AdamW(vae.parameters(), lr=1e-4, weight_decay=1e-4)
|
optimizer_ginka = optim.AdamW(vae.parameters(), lr=3e-4, weight_decay=1e-4)
|
||||||
scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=800, eta_min=1e-6)
|
scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=800, eta_min=1e-6)
|
||||||
|
|
||||||
criterion = VAELoss()
|
criterion = VAELoss()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user