mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-19 08:21:11 +08:00
chore: 调整kl参数
This commit is contained in:
parent
dd9d8a3713
commit
45cfa3b611
@ -54,7 +54,7 @@ from shared.image import matrix_to_image_cv
|
||||
|
||||
BATCH_SIZE = 128
|
||||
LATENT_DIM = 48
|
||||
KL_BETA = 0.05
|
||||
KL_BETA = 0.1
|
||||
|
||||
device = torch.device(
|
||||
"cuda:1" if torch.cuda.is_available()
|
||||
@ -145,21 +145,22 @@ def train():
|
||||
)
|
||||
|
||||
# 验证集
|
||||
with torch.no_grad():
|
||||
val_loss_total = torch.Tensor([0]).to(device)
|
||||
for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
|
||||
target_map = batch["target_map"].to(device)
|
||||
# with torch.no_grad():
|
||||
# val_loss_total = torch.Tensor([0]).to(device)
|
||||
# for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
|
||||
# target_map = batch["target_map"].to(device)
|
||||
|
||||
fake_logits, mu, logvar = vae(target_map, 1 - gt_prob)
|
||||
# fake_logits, mu, logvar = vae(target_map, 1 - gt_prob)
|
||||
|
||||
loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, KL_BETA)
|
||||
val_loss_total += loss.detach()
|
||||
# loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, KL_BETA)
|
||||
# val_loss_total += loss.detach()
|
||||
|
||||
avg_loss_val = val_loss_total.item() / len(dataloader_val)
|
||||
if avg_loss_val < 0.5 and gt_prob > 0:
|
||||
# avg_loss_val = val_loss_total.item() / len(dataloader_val)
|
||||
# 先使用训练集的损失值,因为过拟合比较严重,后续再想办法
|
||||
if avg_loss < 0.5 and gt_prob > 0:
|
||||
gt_prob -= 0.01
|
||||
|
||||
scheduler_ginka.step(avg_loss_val)
|
||||
scheduler_ginka.step(avg_loss)
|
||||
|
||||
# 每若干轮输出一次图片,并保存检查点
|
||||
if (epoch + 1) % args.checkpoint == 0:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user