mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-20 17:22:41 +08:00
feat: 自定义调度器
This commit is contained in:
parent
45cfa3b611
commit
1352d64a50
@ -12,6 +12,7 @@ from torch_geometric.loader import DataLoader
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from .vae_rnn.vae import GinkaVAE
|
from .vae_rnn.vae import GinkaVAE
|
||||||
from .vae_rnn.loss import VAELoss
|
from .vae_rnn.loss import VAELoss
|
||||||
|
from .vae_rnn.scheduler import VAEScheduler
|
||||||
from .dataset import GinkaRNNDataset
|
from .dataset import GinkaRNNDataset
|
||||||
from shared.image import matrix_to_image_cv
|
from shared.image import matrix_to_image_cv
|
||||||
|
|
||||||
@ -92,11 +93,14 @@ def train():
|
|||||||
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // 64, shuffle=True)
|
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // 64, shuffle=True)
|
||||||
|
|
||||||
optimizer_ginka = optim.AdamW(vae.parameters(), lr=2e-4, weight_decay=1e-4)
|
optimizer_ginka = optim.AdamW(vae.parameters(), lr=2e-4, weight_decay=1e-4)
|
||||||
scheduler_ginka = optim.lr_scheduler.ReduceLROnPlateau(optimizer_ginka, factor=0.9, patience=40, min_lr=1e-6)
|
# 自定义调度器允许在 self_prob 提高时重置调度器信息并提高学习率以适应学习
|
||||||
|
scheduler_ginka = VAEScheduler(
|
||||||
|
optimizer_ginka, factor=0.9, increase_factor=1.1, patience=10, max_lr=2e-4, min_lr=1e-6
|
||||||
|
)
|
||||||
|
|
||||||
criterion = VAELoss()
|
criterion = VAELoss()
|
||||||
|
|
||||||
gt_prob = 1
|
self_prob = 0
|
||||||
|
|
||||||
# 用于生成图片
|
# 用于生成图片
|
||||||
tile_dict = dict()
|
tile_dict = dict()
|
||||||
@ -124,7 +128,7 @@ def train():
|
|||||||
target_map = batch["target_map"].to(device)
|
target_map = batch["target_map"].to(device)
|
||||||
|
|
||||||
optimizer_ginka.zero_grad()
|
optimizer_ginka.zero_grad()
|
||||||
fake_logits, mu, logvar = vae(target_map, 1 - gt_prob)
|
fake_logits, mu, logvar = vae(target_map, self_prob)
|
||||||
|
|
||||||
loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, KL_BETA)
|
loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, KL_BETA)
|
||||||
|
|
||||||
@ -141,7 +145,7 @@ def train():
|
|||||||
tqdm.write(
|
tqdm.write(
|
||||||
f"[Epoch {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
|
f"[Epoch {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
|
||||||
f"E: {epoch + 1} | Loss: {avg_loss:.6f} | Reco: {avg_reco_loss:.6f} | " +
|
f"E: {epoch + 1} | Loss: {avg_loss:.6f} | Reco: {avg_reco_loss:.6f} | " +
|
||||||
f"KL: {avg_kl_loss:.6f} | Prob: {gt_prob:.2f} | LR: {optimizer_ginka.param_groups[0]['lr']:.6f}"
|
f"KL: {avg_kl_loss:.6f} | Prob: {self_prob:.2f} | LR: {scheduler_ginka.get_last_lr()[0]:.6f}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 验证集
|
# 验证集
|
||||||
@ -157,10 +161,10 @@ def train():
|
|||||||
|
|
||||||
# avg_loss_val = val_loss_total.item() / len(dataloader_val)
|
# avg_loss_val = val_loss_total.item() / len(dataloader_val)
|
||||||
# 先使用训练集的损失值,因为过拟合比较严重,后续再想办法
|
# 先使用训练集的损失值,因为过拟合比较严重,后续再想办法
|
||||||
if avg_loss < 0.5 and gt_prob > 0:
|
if avg_loss < 0.5 and self_prob < 1:
|
||||||
gt_prob -= 0.01
|
self_prob += 0.01
|
||||||
|
|
||||||
scheduler_ginka.step(avg_loss)
|
scheduler_ginka.step(avg_loss, self_prob)
|
||||||
|
|
||||||
# 每若干轮输出一次图片,并保存检查点
|
# 每若干轮输出一次图片,并保存检查点
|
||||||
if (epoch + 1) % args.checkpoint == 0:
|
if (epoch + 1) % args.checkpoint == 0:
|
||||||
@ -182,7 +186,7 @@ def train():
|
|||||||
for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
|
for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
|
||||||
target_map = batch["target_map"].to(device)
|
target_map = batch["target_map"].to(device)
|
||||||
|
|
||||||
fake_logits, mu, logvar = vae(target_map, 1 - gt_prob)
|
fake_logits, mu, logvar = vae(target_map, self_prob)
|
||||||
|
|
||||||
loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, KL_BETA)
|
loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, KL_BETA)
|
||||||
val_loss_total += loss.detach()
|
val_loss_total += loss.detach()
|
||||||
@ -239,9 +243,6 @@ def train():
|
|||||||
f"Loss: {avg_loss_val:.6f} | Reco: {avg_reco_loss_val:.6f} | KL: {avg_kl_loss_val:.6f}"
|
f"Loss: {avg_loss_val:.6f} | Reco: {avg_reco_loss_val:.6f} | KL: {avg_kl_loss_val:.6f}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if avg_loss_val < 0.5 and gt_prob > 0:
|
|
||||||
gt_prob -= 0.01
|
|
||||||
|
|
||||||
print("Train ended.")
|
print("Train ended.")
|
||||||
torch.save({
|
torch.save({
|
||||||
"model_state": vae.state_dict(),
|
"model_state": vae.state_dict(),
|
||||||
|
|||||||
43
ginka/vae_rnn/scheduler.py
Normal file
43
ginka/vae_rnn/scheduler.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
class VAEScheduler(torch.optim.lr_scheduler.ReduceLROnPlateau):
|
||||||
|
def __init__(
|
||||||
|
self, optimizer, mode="min", factor=0.1, patience=10, threshold=0.0001,
|
||||||
|
threshold_mode="rel", cooldown=0, min_lr=0, eps=1e-8, verbose="deprecated",
|
||||||
|
max_lr=1e-2, increase_factor=2, start_prob=0
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
optimizer, mode, factor, patience, threshold,
|
||||||
|
threshold_mode, cooldown, min_lr, eps, verbose
|
||||||
|
)
|
||||||
|
self.max_lr = max_lr
|
||||||
|
self.increase_factor = increase_factor
|
||||||
|
self.last_prob = start_prob
|
||||||
|
|
||||||
|
if isinstance(max_lr, (list, tuple)):
|
||||||
|
if len(max_lr) != len(optimizer.param_groups):
|
||||||
|
raise ValueError(
|
||||||
|
f"expected {len(optimizer.param_groups)} max_lrs, got {len(max_lr)}"
|
||||||
|
)
|
||||||
|
self.default_max_lr = None
|
||||||
|
self.max_lrs = list(max_lr)
|
||||||
|
else:
|
||||||
|
self.default_max_lr = max_lr
|
||||||
|
self.max_lrs = [max_lr] * len(optimizer.param_groups)
|
||||||
|
|
||||||
|
def step(self, metrics, prob: float, epoch=None):
|
||||||
|
if prob > self.last_prob:
|
||||||
|
self.best = metrics
|
||||||
|
self.num_bad_epochs = 0
|
||||||
|
self.last_prob = prob
|
||||||
|
self._increase_lr()
|
||||||
|
self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
|
||||||
|
else:
|
||||||
|
return super().step(metrics, epoch)
|
||||||
|
|
||||||
|
def _increase_lr(self):
|
||||||
|
for i, param_group in enumerate(self.optimizer.param_groups):
|
||||||
|
old_lr = float(param_group["lr"])
|
||||||
|
new_lr = min(old_lr * self.increase_factor, self.max_lrs[i])
|
||||||
|
if new_lr - old_lr > self.eps:
|
||||||
|
param_group["lr"] = new_lr
|
||||||
Loading…
Reference in New Issue
Block a user