fix: 修改超参数

This commit is contained in:
unanmed 2025-12-13 19:46:09 +08:00
parent ae53194694
commit d42c2eee43
2 changed files with 6 additions and 5 deletions

View File

@ -408,4 +408,4 @@ class RNNGinkaLoss:
def rnn_loss(self, fake, target):
target = F.one_hot(target, num_classes=32).float()
return F.cross_entropy(fake, target)
return F.cross_entropy(fake, target, label_smoothing=0.05)

View File

@ -83,7 +83,7 @@ def train():
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE)
optimizer_ginka = optim.Adam(ginka_rnn.parameters(), lr=1e-3)
optimizer_ginka = optim.AdamW(ginka_rnn.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=10, T_mult=2)
criterion = RNNGinkaLoss()
@ -119,6 +119,7 @@ def train():
loss = criterion.rnn_loss(fake_logits, target_map)
loss.backward()
torch.nn.utils.clip_grad_norm_(ginka_rnn.parameters(), max_norm=1.0)
optimizer_ginka.step()
loss_total_ginka += loss.detach()
@ -133,9 +134,9 @@ def train():
# f"LR: {optimizer_ginka.param_groups[0]['lr']:.6f}"
# )
avg_loss_ginka = loss_total_ginka.item() / iters
avg_loss_ginka = loss_total_ginka.item() / len(dataloader)
tqdm.write(
f"[Epoch {epoch} {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
f"[Epoch {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
f"E: {epoch + 1} | Loss: {avg_loss_ginka:.6f} | " +
f"LR: {optimizer_ginka.param_groups[0]['lr']:.6f}"
)
@ -170,7 +171,7 @@ def train():
avg_loss_val = val_loss_total.item() / len(dataloader_val)
tqdm.write(
f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] E: {epoch} | " +
f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] E: {epoch + 1} | " +
f"Loss: {avg_loss_val:.6f}"
)