perf: 调优部分超参数

This commit is contained in:
unanmed 2025-04-06 21:19:18 +08:00
parent 29cfb4d029
commit 8130296e1f
2 changed files with 12 additions and 7 deletions

View File

@ -311,7 +311,7 @@ def js_divergence(P, Q, epsilon=1e-10):
return js.mean() # 标量
class WGANGinkaLoss:
def __init__(self, lambda_gp=10, weight=[0.7, 0.2, 0.1], diversity_lamda=0):
def __init__(self, lambda_gp=20, weight=[0.7, 0.2, 0.1], diversity_lamda=0):
self.lambda_gp = lambda_gp # 梯度惩罚系数
self.weight = weight
self.diversity_lamda = diversity_lamda

View File

@ -38,7 +38,7 @@ def train():
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
c_steps = 1
g_steps = 3
g_steps = 4
args = parse_arguments()
@ -103,9 +103,9 @@ def train():
for _ in range(g_steps):
z1 = torch.randn(batch_size, 1024, device=device)
z2 = torch.randn(batch_size, 1024, device=device)
fake_softmax1, fakse_softmax2 = ginka(z1), ginka(z2)
fake_softmax1, fake_softmax2 = ginka(z1), ginka(z2)
loss_g = criterion.generator_loss(minamo, fake_softmax1, fakse_softmax2)
loss_g = criterion.generator_loss(minamo, fake_softmax1, fake_softmax2)
loss_g.backward()
optimizer_ginka.step()
@ -120,14 +120,19 @@ def train():
)
if avg_dis < -9:
g_steps = 7
g_steps = 21
elif avg_dis < -6:
g_steps = 5
g_steps = 14
elif avg_dis < -3:
g_steps = 3
g_steps = 7
else:
g_steps = 1
if avg_dis > 3:
c_steps = 3
else:
c_steps = 1
# 每五轮输出一次图片,并保存检查点
if (epoch + 1) % 5 == 0:
# 输出 20 张图片,每批次 4 张,一共五批