mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
fix: 修改超参数
This commit is contained in:
parent
ae53194694
commit
d42c2eee43
@ -408,4 +408,4 @@ class RNNGinkaLoss:
|
||||
|
||||
def rnn_loss(self, fake, target):
|
||||
target = F.one_hot(target, num_classes=32).float()
|
||||
return F.cross_entropy(fake, target)
|
||||
return F.cross_entropy(fake, target, label_smoothing=0.05)
|
||||
|
||||
@ -83,7 +83,7 @@ def train():
|
||||
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE)
|
||||
|
||||
optimizer_ginka = optim.Adam(ginka_rnn.parameters(), lr=1e-3)
|
||||
optimizer_ginka = optim.AdamW(ginka_rnn.parameters(), lr=1e-4, weight_decay=1e-4)
|
||||
scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=10, T_mult=2)
|
||||
|
||||
criterion = RNNGinkaLoss()
|
||||
@ -119,6 +119,7 @@ def train():
|
||||
loss = criterion.rnn_loss(fake_logits, target_map)
|
||||
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(ginka_rnn.parameters(), max_norm=1.0)
|
||||
optimizer_ginka.step()
|
||||
loss_total_ginka += loss.detach()
|
||||
|
||||
@ -133,9 +134,9 @@ def train():
|
||||
# f"LR: {optimizer_ginka.param_groups[0]['lr']:.6f}"
|
||||
# )
|
||||
|
||||
avg_loss_ginka = loss_total_ginka.item() / iters
|
||||
avg_loss_ginka = loss_total_ginka.item() / len(dataloader)
|
||||
tqdm.write(
|
||||
f"[Epoch {epoch} {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
|
||||
f"[Epoch {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
|
||||
f"E: {epoch + 1} | Loss: {avg_loss_ginka:.6f} | " +
|
||||
f"LR: {optimizer_ginka.param_groups[0]['lr']:.6f}"
|
||||
)
|
||||
@ -170,7 +171,7 @@ def train():
|
||||
|
||||
avg_loss_val = val_loss_total.item() / len(dataloader_val)
|
||||
tqdm.write(
|
||||
f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] E: {epoch} | " +
|
||||
f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] E: {epoch + 1} | " +
|
||||
f"Loss: {avg_loss_val:.6f}"
|
||||
)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user