diff --git a/ginka/generator/loss.py b/ginka/generator/loss.py
index 1220cbb..d3f0cfb 100644
--- a/ginka/generator/loss.py
+++ b/ginka/generator/loss.py
@@ -408,4 +408,4 @@ class RNNGinkaLoss:
 
     def rnn_loss(self, fake, target):
         target = F.one_hot(target, num_classes=32).float()
-        return F.cross_entropy(fake, target)
+        return F.cross_entropy(fake, target, label_smoothing=0.05)
diff --git a/ginka/train_rnn.py b/ginka/train_rnn.py
index 8807c08..9bd7a84 100644
--- a/ginka/train_rnn.py
+++ b/ginka/train_rnn.py
@@ -83,7 +83,7 @@ def train():
     dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
     dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE)
 
-    optimizer_ginka = optim.Adam(ginka_rnn.parameters(), lr=1e-3)
+    optimizer_ginka = optim.AdamW(ginka_rnn.parameters(), lr=1e-4, weight_decay=1e-4)
     scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=10, T_mult=2)
     criterion = RNNGinkaLoss()
 
@@ -119,6 +119,7 @@ def train():
             loss = criterion.rnn_loss(fake_logits, target_map)
             loss.backward()
+            torch.nn.utils.clip_grad_norm_(ginka_rnn.parameters(), max_norm=1.0)
             optimizer_ginka.step()
 
             loss_total_ginka += loss.detach()
 
@@ -133,9 +134,9 @@ def train():
             #     f"LR: {optimizer_ginka.param_groups[0]['lr']:.6f}"
             # )
 
-        avg_loss_ginka = loss_total_ginka.item() / iters
+        avg_loss_ginka = loss_total_ginka.item() / len(dataloader)
         tqdm.write(
-            f"[Epoch {epoch} {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
+            f"[Epoch {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
             f"E: {epoch + 1} | Loss: {avg_loss_ginka:.6f} | " +
             f"LR: {optimizer_ginka.param_groups[0]['lr']:.6f}"
         )
@@ -170,7 +171,7 @@ def train():
 
         avg_loss_val = val_loss_total.item() / len(dataloader_val)
         tqdm.write(
-            f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] E: {epoch} | " +
+            f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] E: {epoch + 1} | " +
             f"Loss: {avg_loss_val:.6f}"
         )
 
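Taken together, the patch swaps Adam for AdamW with weight decay at a lower learning rate, inserts gradient-norm clipping between backward() and step(), and enables label smoothing on the one-hot cross-entropy loss. Below is a minimal, self-contained sketch of the resulting training step; the nn.Linear stand-in model and the dummy batch are assumptions for illustration only, not the repo's actual ginka_rnn:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim

model = nn.Linear(16, 32)  # stand-in for ginka_rnn (assumed for this sketch)
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

x = torch.randn(8, 16)                   # dummy batch (assumed shape)
target_map = torch.randint(0, 32, (8,))  # tile class indices in [0, 32)

# forward + loss, mirroring RNNGinkaLoss.rnn_loss after the patch
fake_logits = model(x)
target = F.one_hot(target_map, num_classes=32).float()
loss = F.cross_entropy(fake_logits, target, label_smoothing=0.05)

optimizer.zero_grad()
loss.backward()
# clip after backward() (gradients exist) and before step()
# (so the update uses the clipped gradients), as in the patched loop
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
scheduler.step()  # stepped once per epoch in the real training loop

The ordering is the point of the clipping hunk: clip_grad_norm_ rescales the gradients in place, so it only has an effect if it runs between loss.backward() and optimizer.step().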