feat: 更改调度

This commit is contained in:
unanmed 2026-02-06 01:20:27 +08:00
parent bc73cf9cc3
commit b01357d99e

View File

@ -54,7 +54,6 @@ from shared.image import matrix_to_image_cv
BATCH_SIZE = 128
LATENT_DIM = 48
KL_BETA = 0.01
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
os.makedirs("result", exist_ok=True)
@ -63,15 +62,6 @@ os.makedirs("result/ginka_vae_img", exist_ok=True)
disable_tqdm = not sys.stdout.isatty()
def gt_prob(epoch: int, max_epoch: int) -> float:
progress = epoch / max_epoch
if progress < 0.2:
return 1
elif progress < 0.8:
return 1 - (progress - 0.2) / 0.6
else:
return 0
def parse_arguments():
parser = argparse.ArgumentParser(description="training codes")
parser.add_argument("--resume", type=bool, default=False)
@ -96,10 +86,12 @@ def train():
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // 64, shuffle=True)
optimizer_ginka = optim.AdamW(vae.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=800, eta_min=1e-6)
optimizer_ginka = optim.AdamW(vae.parameters(), lr=2e-4, weight_decay=1e-4)
scheduler_ginka = optim.lr_scheduler.ReduceLROnPlateau(optimizer_ginka, factor=0.9, patience=40)
criterion = VAELoss()
gt_prob = 1
# 用于生成图片
tile_dict = dict()
@ -124,9 +116,9 @@ def train():
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
target_map = batch["target_map"].to(device)
fake_logits, mu, logvar = vae(target_map, 1 - gt_prob(epoch, args.epochs))
fake_logits, z = vae(target_map, 1 - gt_prob)
loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, KL_BETA)
loss = criterion.vae_loss(fake_logits, target_map)
loss.backward()
torch.nn.utils.clip_grad_norm_(vae.parameters(), max_norm=1.0)
@ -140,7 +132,7 @@ def train():
f"LR: {optimizer_ginka.param_groups[0]['lr']:.6f}"
)
scheduler_ginka.step()
scheduler_ginka.step(avg_loss)
# 每若干轮输出一次图片,并保存检查点
if (epoch + 1) % args.checkpoint == 0:
@ -151,7 +143,6 @@ def train():
}, f"result/rnn/ginka-{epoch + 1}.pth")
val_loss_total = torch.Tensor([0]).to(device)
reco_loss_total = torch.Tensor([0]).to(device)
with torch.no_grad():
idx = 0
gap = 5
@ -161,9 +152,9 @@ def train():
for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
target_map = batch["target_map"].to(device)
fake_logits, z = vae(target_map, 1 - gt_prob(epoch, args.epochs))
fake_logits, z = vae(target_map, 1 - gt_prob)
loss = criterion.vae_loss(fake_logits, target_map, z, KL_BETA)
loss = criterion.vae_loss(fake_logits, target_map)
val_loss_total += loss.detach()
fake_map = torch.argmax(fake_logits, dim=1).cpu().numpy()
@ -212,6 +203,9 @@ def train():
f"Loss: {avg_loss_val:.6f}"
)
if avg_loss_val < 0.5 and gt_prob > 0:
gt_prob -= 0.1
print("Train ended.")
torch.save({
"model_state": vae.state_dict(),