diff --git a/ginka/train_vae.py b/ginka/train_vae.py
index fe66d4b..d72fa36 100644
--- a/ginka/train_vae.py
+++ b/ginka/train_vae.py
@@ -54,7 +54,6 @@ from shared.image import matrix_to_image_cv
 
 BATCH_SIZE = 128
 LATENT_DIM = 48
-KL_BETA = 0.01
 
 device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
 os.makedirs("result", exist_ok=True)
@@ -63,15 +62,6 @@ os.makedirs("result/ginka_vae_img", exist_ok=True)
 
 disable_tqdm = not sys.stdout.isatty()
 
-def gt_prob(epoch: int, max_epoch: int) -> float:
-    progress = epoch / max_epoch
-    if progress < 0.2:
-        return 1
-    elif progress < 0.8:
-        return 1 - (progress - 0.2) / 0.6
-    else:
-        return 0
-
 def parse_arguments():
     parser = argparse.ArgumentParser(description="training codes")
     parser.add_argument("--resume", type=bool, default=False)
@@ -96,10 +86,12 @@ def train():
     dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
     dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // 64, shuffle=True)
     
-    optimizer_ginka = optim.AdamW(vae.parameters(), lr=1e-4, weight_decay=1e-4)
-    scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=800, eta_min=1e-6)
+    optimizer_ginka = optim.AdamW(vae.parameters(), lr=2e-4, weight_decay=1e-4)
+    scheduler_ginka = optim.lr_scheduler.ReduceLROnPlateau(optimizer_ginka, factor=0.9, patience=40)
     criterion = VAELoss()
+    
+    gt_prob = 1
     
     # 用于生成图片
     tile_dict = dict()
     
@@ -124,9 +116,9 @@ def train():
         for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
             target_map = batch["target_map"].to(device)
             
-            fake_logits, mu, logvar = vae(target_map, 1 - gt_prob(epoch, args.epochs))
+            fake_logits, z = vae(target_map, 1 - gt_prob)
             
-            loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, KL_BETA)
+            loss = criterion.vae_loss(fake_logits, target_map)
             
             loss.backward()
             torch.nn.utils.clip_grad_norm_(vae.parameters(), max_norm=1.0)
@@ -140,7 +132,7 @@ def train():
             f"LR: {optimizer_ginka.param_groups[0]['lr']:.6f}"
         )
         
-        scheduler_ginka.step()
+        scheduler_ginka.step(avg_loss)
         
         # 每若干轮输出一次图片,并保存检查点
         if (epoch + 1) % args.checkpoint == 0:
@@ -151,7 +143,6 @@ def train():
             }, f"result/rnn/ginka-{epoch + 1}.pth")
             
             val_loss_total = torch.Tensor([0]).to(device)
-            reco_loss_total = torch.Tensor([0]).to(device)
             with torch.no_grad():
                 idx = 0
                 gap = 5
@@ -161,9 +152,9 @@ def train():
                 for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
                     target_map = batch["target_map"].to(device)
                     
-                    fake_logits, z = vae(target_map, 1 - gt_prob(epoch, args.epochs))
+                    fake_logits, z = vae(target_map, 1 - gt_prob)
                     
-                    loss = criterion.vae_loss(fake_logits, target_map, z, KL_BETA)
+                    loss = criterion.vae_loss(fake_logits, target_map)
                     val_loss_total += loss.detach()
                     
                     fake_map = torch.argmax(fake_logits, dim=1).cpu().numpy()
@@ -212,6 +203,9 @@ def train():
                 f"Loss: {avg_loss_val:.6f}"
             )
             
+            if avg_loss_val < 0.5 and gt_prob > 0:
+                gt_prob -= 0.1
+            
     print("Train ended.")
     torch.save({
         "model_state": vae.state_dict(),