From 310b3fae80d192f282287168205d80c5d0df3241 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Thu, 11 Dec 2025 22:55:55 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E9=80=82=E9=85=8D=E6=96=B0=E7=9A=84?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E9=9B=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/common/cond.py | 4 ++-- ginka/train_wgan.py | 41 +++++++++++++++++++++++++++++------------ train.sh | 2 +- 3 files changed, 32 insertions(+), 15 deletions(-) diff --git a/ginka/common/cond.py b/ginka/common/cond.py index 485bad5..1b133fa 100644 --- a/ginka/common/cond.py +++ b/ginka/common/cond.py @@ -25,10 +25,10 @@ class ConditionEncoder(nn.Module): ) def forward(self, tag, val, stage): - tag = self.tag_embed(tag) + # tag = self.tag_embed(tag) val = self.val_embed(val) stage = self.stage_embed(stage) - feat = torch.stack([tag, val, stage], dim=1) + feat = torch.stack([val, stage], dim=1) feat = self.encoder(feat) feat = torch.mean(feat, dim=1) feat = self.fusion(feat) diff --git a/ginka/train_wgan.py b/ginka/train_wgan.py index cce18e1..9f982a1 100644 --- a/ginka/train_wgan.py +++ b/ginka/train_wgan.py @@ -128,8 +128,8 @@ def train(): optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9)) optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-4, betas=(0.0, 0.9)) - scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=100, T_mult=1) - scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=100, T_mult=1) + scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=10, T_mult=2) + scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=10, T_mult=2) criterion = WGANGinkaLoss() @@ -196,6 +196,8 @@ def train(): dis_total = torch.Tensor([0]).to(device) loss_ce_total = torch.Tensor([0]).to(device) + iters = 0 + for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm): rand = batch["rand"].to(device) real0 = batch["real0"].to(device) @@ -280,12 +282,27 @@ def train(): optimizer_ginka.step() loss_total_ginka += loss_g.detach() + iters += 1 + + if iters % 50 == 0: + avg_loss_ginka = loss_total_ginka.item() / iters / g_steps + avg_loss_minamo = loss_total_minamo.item() / iters / c_steps + avg_loss_ce = loss_ce_total.item() / iters / g_steps + avg_dis = dis_total.item() / iters / c_steps + tqdm.write( + f"[Iters {iters} {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " + + f"E: {epoch + 1} | S: {train_stage} | W: {avg_dis:.6f} | " + + f"G: {avg_loss_ginka:.6f} | D: {avg_loss_minamo:.6f} | " + + f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f} | " + + f"LR: {optimizer_ginka.param_groups[0]['lr']:.6f}" + ) + avg_loss_ginka = loss_total_ginka.item() / len(dataloader) / g_steps avg_loss_minamo = loss_total_minamo.item() / len(dataloader) / c_steps avg_loss_ce = loss_ce_total.item() / len(dataloader) / g_steps avg_dis = dis_total.item() / len(dataloader) / c_steps tqdm.write( - f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " + + f"[Epoch {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " + f"E: {epoch + 1} | S: {train_stage} | W: {avg_dis:.6f} | " + f"G: {avg_loss_ginka:.6f} | D: {avg_loss_minamo:.6f} | " + f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f} | " + @@ -368,14 +385,14 @@ def train(): # 训练流程控制 - if train_stage >= 2: - # train_stage = 4 - if (epoch + 1) % 100 == 5: - train_stage = 3 - elif (epoch + 1) % 100 == 20: - train_stage = 4 - elif (epoch + 1) % 100 == 0: - train_stage = 2 + # if train_stage >= 2: + # # train_stage = 4 + # if (epoch + 1) % 100 == 5: + # train_stage = 3 + # elif (epoch + 1) % 100 == 20: + # train_stage = 4 + # elif (epoch + 1) % 100 == 0: + # train_stage = 2 if train_stage == 1: if (mask_ratio < 0.3 and stage_epoch >= first_curr) or \ @@ -385,7 +402,7 @@ def train(): stage_epoch = 0 if mask_ratio >= 0.8: - train_stage = 2 + train_stage = 4 stage_epoch += 1 total_epoch += 1 diff --git a/train.sh b/train.sh index 289e32e..d13229f 100644 --- a/train.sh +++ b/train.sh @@ -1,4 +1,4 @@ # 从头训练 -python3 -u -m ginka.train_wgan --epochs 300 >> output.log +python3 -u -m ginka.train_wgan --epochs 20 --curr_epoch 1 --checkpoint 1 >> output.log # 接续训练 python3 -u -m ginka.train_wgan --resume true --epochs 300 --state_ginka "result/wgan/ginka-100.pth" --state_minamo "result/wgan/minamo-100.pth" >> output.log