Mirror of https://github.com/unanmed/ginka-generator.git (synced 2026-05-14 04:41:12 +08:00)
fix: adjust hyperparameters
commit d42c2eee43
parent ae53194694
@@ -408,4 +408,4 @@ class RNNGinkaLoss:
 
     def rnn_loss(self, fake, target):
         target = F.one_hot(target, num_classes=32).float()
-        return F.cross_entropy(fake, target)
+        return F.cross_entropy(fake, target, label_smoothing=0.05)
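Note on the loss change: the new call adds label smoothing of 0.05 to the cross-entropy term, which softens the targets and can reduce overconfidence. A minimal illustrative sketch, using random stand-in tensors rather than the repo's map data:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 32)             # stand-in predictions, 32 classes as in the diff
labels = torch.randint(0, 32, (4,))     # stand-in class indices

loss_hard = F.cross_entropy(logits, labels)                           # behaviour before the change
loss_smooth = F.cross_entropy(logits, labels, label_smoothing=0.05)   # behaviour after the change
print(loss_hard.item(), loss_smooth.item())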
@@ -83,7 +83,7 @@ def train():
     dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
     dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE)
 
-    optimizer_ginka = optim.Adam(ginka_rnn.parameters(), lr=1e-3)
+    optimizer_ginka = optim.AdamW(ginka_rnn.parameters(), lr=1e-4, weight_decay=1e-4)
     scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=10, T_mult=2)
 
     criterion = RNNGinkaLoss()
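Note on the optimizer change: Adam at lr=1e-3 is replaced by AdamW at lr=1e-4 with decoupled weight decay of 1e-4, while the warm-restart cosine schedule is kept. A sketch of the new setup, with a toy module standing in for ginka_rnn:

import torch.nn as nn
import torch.optim as optim

model = nn.Linear(16, 32)  # stand-in for ginka_rnn

# AdamW applies weight decay directly to the weights instead of folding it
# into the gradient update; lr and weight_decay values mirror the diff.
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)

# First cosine cycle lasts 10 epochs; each restart doubles the cycle length (T_mult=2).
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)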
@@ -119,6 +119,7 @@ def train():
             loss = criterion.rnn_loss(fake_logits, target_map)
 
             loss.backward()
+            torch.nn.utils.clip_grad_norm_(ginka_rnn.parameters(), max_norm=1.0)
             optimizer_ginka.step()
             loss_total_ginka += loss.detach()
 
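Note on the added line: the global gradient norm is clipped to 1.0 between the backward pass and the optimizer step, which bounds the update size when RNN gradients spike. A minimal sketch of that ordering, with stand-in model and data:

import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(8, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
x, y = torch.randn(2, 8), torch.randint(0, 4, (2,))

optimizer.zero_grad()
loss = F.cross_entropy(model(x), y)
loss.backward()
# Clip after backward() and before step(), matching the diff.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()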
@@ -133,9 +134,9 @@ def train():
             # f"LR: {optimizer_ginka.param_groups[0]['lr']:.6f}"
             # )
 
-        avg_loss_ginka = loss_total_ginka.item() / iters
+        avg_loss_ginka = loss_total_ginka.item() / len(dataloader)
         tqdm.write(
-            f"[Epoch {epoch} {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
+            f"[Epoch {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
             f"E: {epoch + 1} | Loss: {avg_loss_ginka:.6f} | " +
             f"LR: {optimizer_ginka.param_groups[0]['lr']:.6f}"
         )
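Note on the averaging change: the accumulated loss is now divided by len(dataloader) (the number of batches) rather than a separate `iters` counter, so the logged value is the mean per-batch loss for the epoch. A toy illustration of the bookkeeping:

import torch

batch_losses = [torch.tensor(0.9), torch.tensor(0.7), torch.tensor(0.8)]  # stand-in per-batch losses
loss_total = torch.tensor(0.0)
for loss in batch_losses:
    loss_total += loss.detach()

avg_loss = loss_total.item() / len(batch_losses)  # len(dataloader) in the real loop
print(f"Loss: {avg_loss:.6f}")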
@@ -170,7 +171,7 @@ def train():
 
         avg_loss_val = val_loss_total.item() / len(dataloader_val)
         tqdm.write(
-            f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] E: {epoch} | " +
+            f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] E: {epoch + 1} | " +
             f"Loss: {avg_loss_val:.6f}"
         )
 