mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-22 11:01:12 +08:00
chore: 调整调度方式
This commit is contained in:
parent
1352d64a50
commit
a07d2cf960
@ -56,6 +56,8 @@ from shared.image import matrix_to_image_cv
|
|||||||
BATCH_SIZE = 128
|
BATCH_SIZE = 128
|
||||||
LATENT_DIM = 48
|
LATENT_DIM = 48
|
||||||
KL_BETA = 0.1
|
KL_BETA = 0.1
|
||||||
|
SELF_GATE = 0.5
|
||||||
|
GATE_EPOCH = 5
|
||||||
|
|
||||||
device = torch.device(
|
device = torch.device(
|
||||||
"cuda:1" if torch.cuda.is_available()
|
"cuda:1" if torch.cuda.is_available()
|
||||||
@ -101,6 +103,7 @@ def train():
|
|||||||
criterion = VAELoss()
|
criterion = VAELoss()
|
||||||
|
|
||||||
self_prob = 0
|
self_prob = 0
|
||||||
|
prob_epochs = 0
|
||||||
|
|
||||||
# 用于生成图片
|
# 用于生成图片
|
||||||
tile_dict = dict()
|
tile_dict = dict()
|
||||||
@ -153,16 +156,21 @@ def train():
|
|||||||
# val_loss_total = torch.Tensor([0]).to(device)
|
# val_loss_total = torch.Tensor([0]).to(device)
|
||||||
# for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
|
# for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
|
||||||
# target_map = batch["target_map"].to(device)
|
# target_map = batch["target_map"].to(device)
|
||||||
|
|
||||||
# fake_logits, mu, logvar = vae(target_map, 1 - gt_prob)
|
# fake_logits, mu, logvar = vae(target_map, 1 - gt_prob)
|
||||||
|
|
||||||
# loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, KL_BETA)
|
# loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, KL_BETA)
|
||||||
# val_loss_total += loss.detach()
|
# val_loss_total += loss.detach()
|
||||||
|
|
||||||
# avg_loss_val = val_loss_total.item() / len(dataloader_val)
|
# avg_loss_val = val_loss_total.item() / len(dataloader_val)
|
||||||
|
|
||||||
# 先使用训练集的损失值,因为过拟合比较严重,后续再想办法
|
# 先使用训练集的损失值,因为过拟合比较严重,后续再想办法
|
||||||
if avg_loss < 0.5 and self_prob < 1:
|
if avg_loss < SELF_GATE:
|
||||||
|
gate_epochs += 1
|
||||||
|
|
||||||
|
if gate_epochs >= GATE_EPOCH and self_prob < 1:
|
||||||
self_prob += 0.01
|
self_prob += 0.01
|
||||||
|
gate_epochs = 0
|
||||||
|
|
||||||
scheduler_ginka.step(avg_loss, self_prob)
|
scheduler_ginka.step(avg_loss, self_prob)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user