mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-15 21:57:52 +08:00
feat: 更改调度
This commit is contained in:
parent
bc73cf9cc3
commit
b01357d99e
@ -54,7 +54,6 @@ from shared.image import matrix_to_image_cv
|
||||
|
||||
BATCH_SIZE = 128
|
||||
LATENT_DIM = 48
|
||||
KL_BETA = 0.01
|
||||
|
||||
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
|
||||
os.makedirs("result", exist_ok=True)
|
||||
@ -63,15 +62,6 @@ os.makedirs("result/ginka_vae_img", exist_ok=True)
|
||||
|
||||
disable_tqdm = not sys.stdout.isatty()
|
||||
|
||||
def gt_prob(epoch: int, max_epoch: int) -> float:
|
||||
progress = epoch / max_epoch
|
||||
if progress < 0.2:
|
||||
return 1
|
||||
elif progress < 0.8:
|
||||
return 1 - (progress - 0.2) / 0.6
|
||||
else:
|
||||
return 0
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(description="training codes")
|
||||
parser.add_argument("--resume", type=bool, default=False)
|
||||
@ -96,10 +86,12 @@ def train():
|
||||
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // 64, shuffle=True)
|
||||
|
||||
optimizer_ginka = optim.AdamW(vae.parameters(), lr=1e-4, weight_decay=1e-4)
|
||||
scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=800, eta_min=1e-6)
|
||||
optimizer_ginka = optim.AdamW(vae.parameters(), lr=2e-4, weight_decay=1e-4)
|
||||
scheduler_ginka = optim.lr_scheduler.ReduceLROnPlateau(optimizer_ginka, factor=0.9, patience=40)
|
||||
|
||||
criterion = VAELoss()
|
||||
|
||||
gt_prob = 1
|
||||
|
||||
# 用于生成图片
|
||||
tile_dict = dict()
|
||||
@ -124,9 +116,9 @@ def train():
|
||||
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
|
||||
target_map = batch["target_map"].to(device)
|
||||
|
||||
fake_logits, mu, logvar = vae(target_map, 1 - gt_prob(epoch, args.epochs))
|
||||
fake_logits, z = vae(target_map, 1 - gt_prob)
|
||||
|
||||
loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, KL_BETA)
|
||||
loss = criterion.vae_loss(fake_logits, target_map)
|
||||
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(vae.parameters(), max_norm=1.0)
|
||||
@ -140,7 +132,7 @@ def train():
|
||||
f"LR: {optimizer_ginka.param_groups[0]['lr']:.6f}"
|
||||
)
|
||||
|
||||
scheduler_ginka.step()
|
||||
scheduler_ginka.step(avg_loss)
|
||||
|
||||
# 每若干轮输出一次图片,并保存检查点
|
||||
if (epoch + 1) % args.checkpoint == 0:
|
||||
@ -151,7 +143,6 @@ def train():
|
||||
}, f"result/rnn/ginka-{epoch + 1}.pth")
|
||||
|
||||
val_loss_total = torch.Tensor([0]).to(device)
|
||||
reco_loss_total = torch.Tensor([0]).to(device)
|
||||
with torch.no_grad():
|
||||
idx = 0
|
||||
gap = 5
|
||||
@ -161,9 +152,9 @@ def train():
|
||||
for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
|
||||
target_map = batch["target_map"].to(device)
|
||||
|
||||
fake_logits, z = vae(target_map, 1 - gt_prob(epoch, args.epochs))
|
||||
fake_logits, z = vae(target_map, 1 - gt_prob)
|
||||
|
||||
loss = criterion.vae_loss(fake_logits, target_map, z, KL_BETA)
|
||||
loss = criterion.vae_loss(fake_logits, target_map)
|
||||
val_loss_total += loss.detach()
|
||||
|
||||
fake_map = torch.argmax(fake_logits, dim=1).cpu().numpy()
|
||||
@ -212,6 +203,9 @@ def train():
|
||||
f"Loss: {avg_loss_val:.6f}"
|
||||
)
|
||||
|
||||
if avg_loss_val < 0.5 and gt_prob > 0:
|
||||
gt_prob -= 0.1
|
||||
|
||||
print("Train ended.")
|
||||
torch.save({
|
||||
"model_state": vae.state_dict(),
|
||||
|
||||
Loading…
Reference in New Issue
Block a user