From 8130296e1f1733f454108398416c3bea44201f86 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Sun, 6 Apr 2025 21:19:18 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20=E8=B0=83=E4=BC=98=E9=83=A8=E5=88=86?= =?UTF-8?q?=E8=B6=85=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/model/loss.py | 2 +- ginka/train_wgan.py | 17 +++++++++++------ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/ginka/model/loss.py b/ginka/model/loss.py index d50842d..7bfb05c 100644 --- a/ginka/model/loss.py +++ b/ginka/model/loss.py @@ -311,7 +311,7 @@ def js_divergence(P, Q, epsilon=1e-10): return js.mean() # 标量 class WGANGinkaLoss: - def __init__(self, lambda_gp=10, weight=[0.7, 0.2, 0.1], diversity_lamda=0): + def __init__(self, lambda_gp=20, weight=[0.7, 0.2, 0.1], diversity_lamda=0): self.lambda_gp = lambda_gp # 梯度惩罚系数 self.weight = weight self.diversity_lamda = diversity_lamda diff --git a/ginka/train_wgan.py b/ginka/train_wgan.py index ac70e9c..8093946 100644 --- a/ginka/train_wgan.py +++ b/ginka/train_wgan.py @@ -38,7 +38,7 @@ def train(): print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.") c_steps = 1 - g_steps = 3 + g_steps = 4 args = parse_arguments() @@ -103,9 +103,9 @@ def train(): for _ in range(g_steps): z1 = torch.randn(batch_size, 1024, device=device) z2 = torch.randn(batch_size, 1024, device=device) - fake_softmax1, fakse_softmax2 = ginka(z1), ginka(z2) + fake_softmax1, fake_softmax2 = ginka(z1), ginka(z2) - loss_g = criterion.generator_loss(minamo, fake_softmax1, fakse_softmax2) + loss_g = criterion.generator_loss(minamo, fake_softmax1, fake_softmax2) loss_g.backward() optimizer_ginka.step() @@ -120,14 +120,19 @@ def train(): ) if avg_dis < -9: - g_steps = 7 + g_steps = 21 elif avg_dis < -6: - g_steps = 5 + g_steps = 14 elif avg_dis < -3: - g_steps = 3 + g_steps = 7 else: g_steps = 1 + if avg_dis > 3: + c_steps = 3 + else: + c_steps = 1 + # 每五轮输出一次图片,并保存检查点 if (epoch + 1) % 5 == 0: # 输出 20 张图片,每批次 4 张,一共五批