feat: 自定义调度器

This commit is contained in:
unanmed 2026-02-09 14:59:09 +08:00
parent 45cfa3b611
commit 1352d64a50
2 changed files with 55 additions and 11 deletions

View File

@ -12,6 +12,7 @@ from torch_geometric.loader import DataLoader
from tqdm import tqdm
from .vae_rnn.vae import GinkaVAE
from .vae_rnn.loss import VAELoss
from .vae_rnn.scheduler import VAEScheduler
from .dataset import GinkaRNNDataset
from shared.image import matrix_to_image_cv
@ -92,11 +93,14 @@ def train():
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // 64, shuffle=True)
optimizer_ginka = optim.AdamW(vae.parameters(), lr=2e-4, weight_decay=1e-4)
scheduler_ginka = optim.lr_scheduler.ReduceLROnPlateau(optimizer_ginka, factor=0.9, patience=40, min_lr=1e-6)
# 自定义调度器允许在 self_prob 提高时重置调度器信息并提高学习率以适应学习
scheduler_ginka = VAEScheduler(
optimizer_ginka, factor=0.9, increase_factor=1.1, patience=10, max_lr=2e-4, min_lr=1e-6
)
criterion = VAELoss()
gt_prob = 1
self_prob = 0
# 用于生成图片
tile_dict = dict()
@ -124,7 +128,7 @@ def train():
target_map = batch["target_map"].to(device)
optimizer_ginka.zero_grad()
fake_logits, mu, logvar = vae(target_map, 1 - gt_prob)
fake_logits, mu, logvar = vae(target_map, self_prob)
loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, KL_BETA)
@ -141,7 +145,7 @@ def train():
tqdm.write(
f"[Epoch {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
f"E: {epoch + 1} | Loss: {avg_loss:.6f} | Reco: {avg_reco_loss:.6f} | " +
f"KL: {avg_kl_loss:.6f} | Prob: {gt_prob:.2f} | LR: {optimizer_ginka.param_groups[0]['lr']:.6f}"
f"KL: {avg_kl_loss:.6f} | Prob: {self_prob:.2f} | LR: {scheduler_ginka.get_last_lr()[0]:.6f}"
)
# 验证集
@ -157,10 +161,10 @@ def train():
# avg_loss_val = val_loss_total.item() / len(dataloader_val)
# 先使用训练集的损失值,因为过拟合比较严重,后续再想办法
if avg_loss < 0.5 and gt_prob > 0:
gt_prob -= 0.01
if avg_loss < 0.5 and self_prob < 1:
self_prob += 0.01
scheduler_ginka.step(avg_loss)
scheduler_ginka.step(avg_loss, self_prob)
# 每若干轮输出一次图片,并保存检查点
if (epoch + 1) % args.checkpoint == 0:
@ -182,7 +186,7 @@ def train():
for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
target_map = batch["target_map"].to(device)
fake_logits, mu, logvar = vae(target_map, 1 - gt_prob)
fake_logits, mu, logvar = vae(target_map, self_prob)
loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, KL_BETA)
val_loss_total += loss.detach()
@ -239,9 +243,6 @@ def train():
f"Loss: {avg_loss_val:.6f} | Reco: {avg_reco_loss_val:.6f} | KL: {avg_kl_loss_val:.6f}"
)
if avg_loss_val < 0.5 and gt_prob > 0:
gt_prob -= 0.01
print("Train ended.")
torch.save({
"model_state": vae.state_dict(),

View File

@ -0,0 +1,43 @@
import torch
class VAEScheduler(torch.optim.lr_scheduler.ReduceLROnPlateau):
def __init__(
self, optimizer, mode="min", factor=0.1, patience=10, threshold=0.0001,
threshold_mode="rel", cooldown=0, min_lr=0, eps=1e-8, verbose="deprecated",
max_lr=1e-2, increase_factor=2, start_prob=0
):
super().__init__(
optimizer, mode, factor, patience, threshold,
threshold_mode, cooldown, min_lr, eps, verbose
)
self.max_lr = max_lr
self.increase_factor = increase_factor
self.last_prob = start_prob
if isinstance(max_lr, (list, tuple)):
if len(max_lr) != len(optimizer.param_groups):
raise ValueError(
f"expected {len(optimizer.param_groups)} max_lrs, got {len(max_lr)}"
)
self.default_max_lr = None
self.max_lrs = list(max_lr)
else:
self.default_max_lr = max_lr
self.max_lrs = [max_lr] * len(optimizer.param_groups)
def step(self, metrics, prob: float, epoch=None):
if prob > self.last_prob:
self.best = metrics
self.num_bad_epochs = 0
self.last_prob = prob
self._increase_lr()
self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
else:
return super().step(metrics, epoch)
def _increase_lr(self):
for i, param_group in enumerate(self.optimizer.param_groups):
old_lr = float(param_group["lr"])
new_lr = min(old_lr * self.increase_factor, self.max_lrs[i])
if new_lr - old_lr > self.eps:
param_group["lr"] = new_lr