From d42c2eee4342f0a6de172eedb292bf2d97aa6096 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Sat, 13 Dec 2025 19:46:09 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E6=94=B9=E8=B6=85=E5=8F=82?= =?UTF-8?q?=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/generator/loss.py | 2 +- ginka/train_rnn.py | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/ginka/generator/loss.py b/ginka/generator/loss.py index 1220cbb..d3f0cfb 100644 --- a/ginka/generator/loss.py +++ b/ginka/generator/loss.py @@ -408,4 +408,4 @@ class RNNGinkaLoss: def rnn_loss(self, fake, target): target = F.one_hot(target, num_classes=32).float() - return F.cross_entropy(fake, target) + return F.cross_entropy(fake, target, label_smoothing=0.05) diff --git a/ginka/train_rnn.py b/ginka/train_rnn.py index 8807c08..9bd7a84 100644 --- a/ginka/train_rnn.py +++ b/ginka/train_rnn.py @@ -83,7 +83,7 @@ def train(): dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True) dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE) - optimizer_ginka = optim.Adam(ginka_rnn.parameters(), lr=1e-3) + optimizer_ginka = optim.AdamW(ginka_rnn.parameters(), lr=1e-4, weight_decay=1e-4) scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=10, T_mult=2) criterion = RNNGinkaLoss() @@ -119,6 +119,7 @@ def train(): loss = criterion.rnn_loss(fake_logits, target_map) loss.backward() + torch.nn.utils.clip_grad_norm_(ginka_rnn.parameters(), max_norm=1.0) optimizer_ginka.step() loss_total_ginka += loss.detach() @@ -133,9 +134,9 @@ def train(): # f"LR: {optimizer_ginka.param_groups[0]['lr']:.6f}" # ) - avg_loss_ginka = loss_total_ginka.item() / iters + avg_loss_ginka = loss_total_ginka.item() / len(dataloader) tqdm.write( - f"[Epoch {epoch} {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " + + f"[Epoch {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " + f"E: {epoch + 1} | Loss: {avg_loss_ginka:.6f} | " + f"LR: {optimizer_ginka.param_groups[0]['lr']:.6f}" ) @@ -170,7 +171,7 @@ def train(): avg_loss_val = val_loss_total.item() / len(dataloader_val) tqdm.write( - f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] E: {epoch} | " + + f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] E: {epoch + 1} | " + f"Loss: {avg_loss_val:.6f}" )