diff --git a/ginka/train_vae.py b/ginka/train_vae.py index 7ec75e4..e440ac8 100644 --- a/ginka/train_vae.py +++ b/ginka/train_vae.py @@ -12,6 +12,7 @@ from torch_geometric.loader import DataLoader from tqdm import tqdm from .vae_rnn.vae import GinkaVAE from .vae_rnn.loss import VAELoss +from .vae_rnn.scheduler import VAEScheduler from .dataset import GinkaRNNDataset from shared.image import matrix_to_image_cv @@ -92,11 +93,14 @@ def train(): dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // 64, shuffle=True) optimizer_ginka = optim.AdamW(vae.parameters(), lr=2e-4, weight_decay=1e-4) - scheduler_ginka = optim.lr_scheduler.ReduceLROnPlateau(optimizer_ginka, factor=0.9, patience=40, min_lr=1e-6) + # 自定义调度器允许在 self_prob 提高时重置调度器信息并提高学习率以适应学习 + scheduler_ginka = VAEScheduler( + optimizer_ginka, factor=0.9, increase_factor=1.1, patience=10, max_lr=2e-4, min_lr=1e-6 + ) criterion = VAELoss() - gt_prob = 1 + self_prob = 0 # 用于生成图片 tile_dict = dict() @@ -124,7 +128,7 @@ def train(): target_map = batch["target_map"].to(device) optimizer_ginka.zero_grad() - fake_logits, mu, logvar = vae(target_map, 1 - gt_prob) + fake_logits, mu, logvar = vae(target_map, self_prob) loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, KL_BETA) @@ -141,7 +145,7 @@ def train(): tqdm.write( f"[Epoch {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " + f"E: {epoch + 1} | Loss: {avg_loss:.6f} | Reco: {avg_reco_loss:.6f} | " + - f"KL: {avg_kl_loss:.6f} | Prob: {gt_prob:.2f} | LR: {optimizer_ginka.param_groups[0]['lr']:.6f}" + f"KL: {avg_kl_loss:.6f} | Prob: {self_prob:.2f} | LR: {scheduler_ginka.get_last_lr()[0]:.6f}" ) # 验证集 @@ -157,10 +161,10 @@ def train(): # avg_loss_val = val_loss_total.item() / len(dataloader_val) # 先使用训练集的损失值,因为过拟合比较严重,后续再想办法 - if avg_loss < 0.5 and gt_prob > 0: - gt_prob -= 0.01 + if avg_loss < 0.5 and self_prob < 1: + self_prob += 0.01 - scheduler_ginka.step(avg_loss) + scheduler_ginka.step(avg_loss, self_prob) # 
import torch


class VAEScheduler(torch.optim.lr_scheduler.ReduceLROnPlateau):
    """ReduceLROnPlateau variant that can also raise the learning rate.

    Besides the usual plateau-based decay, every call to :meth:`step` receives
    a probability value ``prob``. Whenever ``prob`` is greater than the value
    seen on the previous increase (initially ``start_prob``), the plateau
    bookkeeping is reset and each param group's learning rate is multiplied
    by ``increase_factor``, capped at the corresponding ``max_lr``.

    Args:
        optimizer: Wrapped optimizer.
        mode, factor, patience, threshold, threshold_mode, cooldown, min_lr,
            eps: Forwarded unchanged to ``ReduceLROnPlateau``.
        verbose: Forwarded to ``ReduceLROnPlateau`` only when it is not the
            ``"deprecated"`` sentinel, so the class also works with parent
            signatures that no longer accept this argument.
        max_lr: Upper bound for the learning rate; a scalar applied to every
            param group, or a list/tuple with one entry per param group.
        increase_factor: Multiplier applied to the learning rate when
            ``prob`` increases.
        start_prob: Initial value that ``prob`` is compared against.

    Raises:
        ValueError: If ``max_lr`` is a list/tuple whose length differs from
            the number of optimizer param groups.
    """

    def __init__(
        self, optimizer, mode="min", factor=0.1, patience=10, threshold=0.0001,
        threshold_mode="rel", cooldown=0, min_lr=0, eps=1e-8, verbose="deprecated",
        max_lr=1e-2, increase_factor=2, start_prob=0
    ):
        plateau_kwargs = {
            "mode": mode,
            "factor": factor,
            "patience": patience,
            "threshold": threshold,
            "threshold_mode": threshold_mode,
            "cooldown": cooldown,
            "min_lr": min_lr,
            "eps": eps,
        }
        # Only forward ``verbose`` when the caller set it explicitly: passing
        # the sentinel positionally fails on parent signatures that dropped
        # the parameter, while omitting it is equivalent where it still exists.
        if not (isinstance(verbose, str) and verbose == "deprecated"):
            plateau_kwargs["verbose"] = verbose
        super().__init__(optimizer, **plateau_kwargs)

        self.max_lr = max_lr
        self.increase_factor = increase_factor
        self.last_prob = start_prob

        if isinstance(max_lr, (list, tuple)):
            if len(max_lr) != len(optimizer.param_groups):
                raise ValueError(
                    f"expected {len(optimizer.param_groups)} max_lrs, got {len(max_lr)}"
                )
            self.default_max_lr = None
            self.max_lrs = list(max_lr)
        else:
            self.default_max_lr = max_lr
            self.max_lrs = [max_lr] * len(optimizer.param_groups)

    def step(self, metrics, prob: float, epoch=None):
        """Advance the scheduler by one epoch.

        Args:
            metrics: Monitored quantity for this epoch (number or scalar
                tensor).
            prob: Current probability value. If it exceeds the last recorded
                value, the learning rate is increased instead of running the
                plateau logic.
            epoch: Optional explicit epoch index, as accepted by the parent.
        """
        if prob <= self.last_prob:
            # No change in ``prob``: behave exactly like ReduceLROnPlateau.
            return super().step(metrics, epoch)

        # ``prob`` went up: restart plateau tracking from the current metric
        # and bump the learning rate.
        # Store a plain float, as the parent does, so ``best`` never holds a
        # tensor (which would otherwise leak into ``state_dict``).
        self.best = float(metrics)
        self.num_bad_epochs = 0
        self.last_prob = prob
        # Keep the epoch counter in sync with the number of ``step`` calls,
        # mirroring what the parent's ``step`` does.
        self.last_epoch = self.last_epoch + 1 if epoch is None else epoch
        self._increase_lr()
        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]

    def _increase_lr(self):
        """Multiply each group's LR by ``increase_factor``, capped at its max."""
        for group, lr_cap in zip(self.optimizer.param_groups, self.max_lrs):
            old_lr = float(group["lr"])
            new_lr = min(old_lr * self.increase_factor, lr_cap)
            # Skip negligible updates, mirroring the parent's ``eps`` check.
            if new_lr - old_lr > self.eps:
                group["lr"] = new_lr