mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
perf: 调优部分超参数
This commit is contained in:
parent
29cfb4d029
commit
8130296e1f
@ -311,7 +311,7 @@ def js_divergence(P, Q, epsilon=1e-10):
|
||||
return js.mean() # 标量
|
||||
|
||||
class WGANGinkaLoss:
|
||||
def __init__(self, lambda_gp=10, weight=[0.7, 0.2, 0.1], diversity_lamda=0):
|
||||
def __init__(self, lambda_gp=20, weight=[0.7, 0.2, 0.1], diversity_lamda=0):
|
||||
self.lambda_gp = lambda_gp # 梯度惩罚系数
|
||||
self.weight = weight
|
||||
self.diversity_lamda = diversity_lamda
|
||||
|
||||
@ -38,7 +38,7 @@ def train():
|
||||
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
|
||||
|
||||
c_steps = 1
|
||||
g_steps = 3
|
||||
g_steps = 4
|
||||
|
||||
args = parse_arguments()
|
||||
|
||||
@ -103,9 +103,9 @@ def train():
|
||||
for _ in range(g_steps):
|
||||
z1 = torch.randn(batch_size, 1024, device=device)
|
||||
z2 = torch.randn(batch_size, 1024, device=device)
|
||||
fake_softmax1, fakse_softmax2 = ginka(z1), ginka(z2)
|
||||
fake_softmax1, fake_softmax2 = ginka(z1), ginka(z2)
|
||||
|
||||
loss_g = criterion.generator_loss(minamo, fake_softmax1, fakse_softmax2)
|
||||
loss_g = criterion.generator_loss(minamo, fake_softmax1, fake_softmax2)
|
||||
loss_g.backward()
|
||||
optimizer_ginka.step()
|
||||
|
||||
@ -120,14 +120,19 @@ def train():
|
||||
)
|
||||
|
||||
if avg_dis < -9:
|
||||
g_steps = 7
|
||||
g_steps = 21
|
||||
elif avg_dis < -6:
|
||||
g_steps = 5
|
||||
g_steps = 14
|
||||
elif avg_dis < -3:
|
||||
g_steps = 3
|
||||
g_steps = 7
|
||||
else:
|
||||
g_steps = 1
|
||||
|
||||
if avg_dis > 3:
|
||||
c_steps = 3
|
||||
else:
|
||||
c_steps = 1
|
||||
|
||||
# 每五轮输出一次图片,并保存检查点
|
||||
if (epoch + 1) % 5 == 0:
|
||||
# 输出 20 张图片,每批次 4 张,一共五批
|
||||
|
||||
Loading…
Reference in New Issue
Block a user