mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-20 17:22:41 +08:00
chore: 调整部分参数
This commit is contained in:
parent
a07d2cf960
commit
becf625bdb
@ -58,6 +58,7 @@ LATENT_DIM = 48
|
|||||||
KL_BETA = 0.1
|
KL_BETA = 0.1
|
||||||
SELF_GATE = 0.5
|
SELF_GATE = 0.5
|
||||||
GATE_EPOCH = 5
|
GATE_EPOCH = 5
|
||||||
|
VAL_BATCH_DIVIDER = 1
|
||||||
|
|
||||||
device = torch.device(
|
device = torch.device(
|
||||||
"cuda:1" if torch.cuda.is_available()
|
"cuda:1" if torch.cuda.is_available()
|
||||||
@ -92,12 +93,12 @@ def train():
|
|||||||
dataset = GinkaRNNDataset(args.train, device)
|
dataset = GinkaRNNDataset(args.train, device)
|
||||||
dataset_val = GinkaRNNDataset(args.validate, device)
|
dataset_val = GinkaRNNDataset(args.validate, device)
|
||||||
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||||
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // 64, shuffle=True)
|
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // VAL_BATCH_DIVIDER, shuffle=True)
|
||||||
|
|
||||||
optimizer_ginka = optim.AdamW(vae.parameters(), lr=2e-4, weight_decay=1e-4)
|
optimizer_ginka = optim.AdamW(vae.parameters(), lr=2e-4, weight_decay=1e-4)
|
||||||
# 自定义调度器允许在 self_prob 提高时重置调度器信息并提高学习率以适应学习
|
# 自定义调度器允许在 self_prob 提高时重置调度器信息并提高学习率以适应学习
|
||||||
scheduler_ginka = VAEScheduler(
|
scheduler_ginka = VAEScheduler(
|
||||||
optimizer_ginka, factor=0.9, increase_factor=1.1, patience=10, max_lr=2e-4, min_lr=1e-6
|
optimizer_ginka, factor=0.9, increase_factor=2, patience=10, max_lr=2e-4, min_lr=1e-6
|
||||||
)
|
)
|
||||||
|
|
||||||
criterion = VAELoss()
|
criterion = VAELoss()
|
||||||
@ -166,11 +167,11 @@ def train():
|
|||||||
|
|
||||||
# 先使用训练集的损失值,因为过拟合比较严重,后续再想办法
|
# 先使用训练集的损失值,因为过拟合比较严重,后续再想办法
|
||||||
if avg_loss < SELF_GATE:
|
if avg_loss < SELF_GATE:
|
||||||
gate_epochs += 1
|
prob_epochs += 1
|
||||||
|
|
||||||
if gate_epochs >= GATE_EPOCH and self_prob < 1:
|
if prob_epochs >= GATE_EPOCH and self_prob < 1:
|
||||||
self_prob += 0.01
|
self_prob += 0.01
|
||||||
gate_epochs = 0
|
prob_epochs = 0
|
||||||
|
|
||||||
scheduler_ginka.step(avg_loss, self_prob)
|
scheduler_ginka.step(avg_loss, self_prob)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user