mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
feat: custom scheduler
This commit is contained in:
parent
45cfa3b611
commit
1352d64a50
@@ -12,6 +12,7 @@ from torch_geometric.loader import DataLoader
 from tqdm import tqdm
 from .vae_rnn.vae import GinkaVAE
 from .vae_rnn.loss import VAELoss
+from .vae_rnn.scheduler import VAEScheduler
 from .dataset import GinkaRNNDataset
 from shared.image import matrix_to_image_cv
 
@@ -92,11 +93,14 @@ def train():
     dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // 64, shuffle=True)
 
     optimizer_ginka = optim.AdamW(vae.parameters(), lr=2e-4, weight_decay=1e-4)
-    scheduler_ginka = optim.lr_scheduler.ReduceLROnPlateau(optimizer_ginka, factor=0.9, patience=40, min_lr=1e-6)
+    # The custom scheduler can reset its state and raise the learning rate when self_prob increases, to adapt to the harder task
+    scheduler_ginka = VAEScheduler(
+        optimizer_ginka, factor=0.9, increase_factor=1.1, patience=10, max_lr=2e-4, min_lr=1e-6
+    )
 
     criterion = VAELoss()
 
-    gt_prob = 1
+    self_prob = 0
 
     # Used to generate images
     tile_dict = dict()
@@ -124,7 +128,7 @@ def train():
         target_map = batch["target_map"].to(device)
 
         optimizer_ginka.zero_grad()
-        fake_logits, mu, logvar = vae(target_map, 1 - gt_prob)
+        fake_logits, mu, logvar = vae(target_map, self_prob)
 
         loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, KL_BETA)
 
@@ -141,7 +145,7 @@ def train():
         tqdm.write(
             f"[Epoch {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
             f"E: {epoch + 1} | Loss: {avg_loss:.6f} | Reco: {avg_reco_loss:.6f} | " +
-            f"KL: {avg_kl_loss:.6f} | Prob: {gt_prob:.2f} | LR: {optimizer_ginka.param_groups[0]['lr']:.6f}"
+            f"KL: {avg_kl_loss:.6f} | Prob: {self_prob:.2f} | LR: {scheduler_ginka.get_last_lr()[0]:.6f}"
         )
 
         # Validation set
@@ -157,10 +161,10 @@ def train():
 
         # avg_loss_val = val_loss_total.item() / len(dataloader_val)
         # Use the training loss for now, since overfitting is severe; find a better fix later
-        if avg_loss < 0.5 and gt_prob > 0:
-            gt_prob -= 0.01
+        if avg_loss < 0.5 and self_prob < 1:
+            self_prob += 0.01
 
-        scheduler_ginka.step(avg_loss)
+        scheduler_ginka.step(avg_loss, self_prob)
 
         # Every few epochs, output images and save a checkpoint
         if (epoch + 1) % args.checkpoint == 0:
@@ -182,7 +186,7 @@ def train():
             for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
                 target_map = batch["target_map"].to(device)
 
-                fake_logits, mu, logvar = vae(target_map, 1 - gt_prob)
+                fake_logits, mu, logvar = vae(target_map, self_prob)
 
                 loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, KL_BETA)
                 val_loss_total += loss.detach()
@@ -239,9 +243,6 @@ def train():
             f"Loss: {avg_loss_val:.6f} | Reco: {avg_reco_loss_val:.6f} | KL: {avg_kl_loss_val:.6f}"
         )
 
-        if avg_loss_val < 0.5 and gt_prob > 0:
-            gt_prob -= 0.01
-
     print("Train ended.")
     torch.save({
         "model_state": vae.state_dict(),
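A note on the probability flip in the hunks above: the old code tracked gt_prob (by its name, the probability of feeding the ground-truth map, starting at 1 and decaying) and passed 1 - gt_prob to the VAE, while the new code tracks self_prob (the probability of the model consuming its own previous output, starting at 0 and growing) and passes it directly, so self_prob plays the role of the old 1 - gt_prob. The ramp itself is unchanged: once the average training loss drops below 0.5, the probability moves by 0.01 per epoch until fully self-fed. A minimal standalone sketch of that ramp, with placeholder loss values:

    # Ramp logic from the diff in isolation; the losses are placeholders.
    self_prob = 0.0  # replaces gt_prob = 1, with self_prob == 1 - gt_prob
    for epoch, avg_loss in enumerate([0.70, 0.55, 0.45, 0.40, 0.30]):
        if avg_loss < 0.5 and self_prob < 1:
            self_prob += 0.01  # 100 qualifying epochs to reach self_prob == 1
        print(f"epoch {epoch + 1}: self_prob = {self_prob:.2f}")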
43
ginka/vae_rnn/scheduler.py
Normal file
@@ -0,0 +1,43 @@
+import torch
+
+class VAEScheduler(torch.optim.lr_scheduler.ReduceLROnPlateau):
+    def __init__(
+        self, optimizer, mode="min", factor=0.1, patience=10, threshold=0.0001,
+        threshold_mode="rel", cooldown=0, min_lr=0, eps=1e-8, verbose="deprecated",
+        max_lr=1e-2, increase_factor=2, start_prob=0
+    ):
+        super().__init__(
+            optimizer, mode, factor, patience, threshold,
+            threshold_mode, cooldown, min_lr, eps, verbose
+        )
+        self.max_lr = max_lr
+        self.increase_factor = increase_factor
+        self.last_prob = start_prob
+
+        if isinstance(max_lr, (list, tuple)):
+            if len(max_lr) != len(optimizer.param_groups):
+                raise ValueError(
+                    f"expected {len(optimizer.param_groups)} max_lrs, got {len(max_lr)}"
+                )
+            self.default_max_lr = None
+            self.max_lrs = list(max_lr)
+        else:
+            self.default_max_lr = max_lr
+            self.max_lrs = [max_lr] * len(optimizer.param_groups)
+
+    def step(self, metrics, prob: float, epoch=None):
+        if prob > self.last_prob:
+            self.best = metrics
+            self.num_bad_epochs = 0
+            self.last_prob = prob
+            self._increase_lr()
+            self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
+        else:
+            return super().step(metrics, epoch)
+
+    def _increase_lr(self):
+        for i, param_group in enumerate(self.optimizer.param_groups):
+            old_lr = float(param_group["lr"])
+            new_lr = min(old_lr * self.increase_factor, self.max_lrs[i])
+            if new_lr - old_lr > self.eps:
+                param_group["lr"] = new_lr
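Standalone, the new scheduler can be exercised as below. This is a minimal sketch, not part of the commit: the linear model and constant loss are stand-ins for GinkaVAE and the real epoch loss, the import assumes the repo root is on PYTHONPATH with ginka importable as a package, and the hyperparameters mirror the training script. The get_last_lr() call matches the training script's usage and assumes a PyTorch version where ReduceLROnPlateau exposes that method.

    import torch
    from ginka.vae_rnn.scheduler import VAEScheduler

    # Stand-in model and optimizer; the real script uses GinkaVAE with AdamW at lr=2e-4.
    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-4)
    scheduler = VAEScheduler(
        optimizer, factor=0.9, increase_factor=1.1, patience=10, max_lr=2e-4, min_lr=1e-6
    )

    self_prob = 0.0
    for epoch in range(200):
        avg_loss = 0.4  # placeholder; use the real average training loss
        if avg_loss < 0.5 and self_prob < 1:
            self_prob += 0.01
        # If self_prob rose this epoch, step() resets `best` and `num_bad_epochs`
        # and multiplies each group's LR by increase_factor (capped at max_lr);
        # otherwise it decays the LR on plateaus like plain ReduceLROnPlateau.
        scheduler.step(avg_loss, self_prob)
        print(f"epoch {epoch + 1}: lr = {scheduler.get_last_lr()[0]:.2e}")

One design consequence worth noting: because every bump of self_prob resets the plateau counter, the learning rate can only decay during stretches where self_prob is flat, i.e. once it has reached 1 or while the loss stays above the 0.5 threshold.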