mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
fix: 修复损失值计算
This commit is contained in:
parent
9b5abf177c
commit
ef0b7ffba2
@ -295,7 +295,10 @@ def interpolate_graph_features(real_graph, fake_graph, epsilon=0.5):
|
||||
|
||||
return Data(x=x_interp, edge_index=edge_index_interp, edge_attr=edge_attr_interp)
|
||||
|
||||
def js_divergence(p, q, eps=1e-8):
|
||||
def js_divergence(p, q, eps=1e-6, softmax=False):
|
||||
if softmax:
|
||||
p = F.softmax(p, dim=1)
|
||||
q = F.softmax(q, dim=1)
|
||||
# softmax 后变成概率分布
|
||||
m = 0.5 * (p + q)
|
||||
|
||||
@ -303,8 +306,6 @@ def js_divergence(p, q, eps=1e-8):
|
||||
log_p = torch.log(p + eps)
|
||||
log_q = torch.log(q + eps)
|
||||
log_m = torch.log(m + eps)
|
||||
|
||||
nn.KLDivLoss
|
||||
|
||||
kl_pm = F.kl_div(log_p, log_m, reduction='batchmean', log_target=True) # KL(p || m)
|
||||
kl_qm = F.kl_div(log_q, log_m, reduction='batchmean', log_target=True) # KL(q || m)
|
||||
@ -335,7 +336,7 @@ def immutable_penalty_loss(
|
||||
return penalty
|
||||
|
||||
class WGANGinkaLoss:
|
||||
def __init__(self, lambda_gp=100, weight=[1, 0.4, 25, 0.2, 0.2, 0.02]):
|
||||
def __init__(self, lambda_gp=100, weight=[1, 0.4, 25, 0.2, 0.2, 0.01]):
|
||||
# weight: 判别器损失,L1 损失,不可修改类型损失,图块类型损失,入口存在性损失,多样性损失
|
||||
self.lambda_gp = lambda_gp # 梯度惩罚系数
|
||||
self.weight = weight
|
||||
@ -380,6 +381,7 @@ class WGANGinkaLoss:
|
||||
self, critic, stage: int, real_data: torch.Tensor, fake_data: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
""" 判别器损失函数 """
|
||||
fake_data = F.softmax(fake_data, dim=1)
|
||||
real_graph = batch_convert_soft_map_to_graph(real_data)
|
||||
fake_graph = batch_convert_soft_map_to_graph(fake_data)
|
||||
real_scores, _, _ = critic(real_data, real_graph, stage)
|
||||
@ -392,25 +394,17 @@ class WGANGinkaLoss:
|
||||
total_loss = d_loss + self.lambda_gp * grad_loss
|
||||
|
||||
return total_loss, d_loss
|
||||
|
||||
def calculate_similarity_one(self, map1, map2):
|
||||
topo1 = build_topological_graph(map1)
|
||||
topo2 = build_topological_graph(map2)
|
||||
|
||||
vis_sim = calculate_visual_similarity(map1, map2)
|
||||
topo_sim = overall_similarity(topo1, topo2)
|
||||
|
||||
return vis_sim, topo_sim
|
||||
|
||||
def generator_loss(self, critic, stage, mask_ratio, real, fake: torch.Tensor, input) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
""" 生成器损失函数 """
|
||||
fake_graph = batch_convert_soft_map_to_graph(fake)
|
||||
probs_fake = F.softmax(fake, dim=1)
|
||||
fake_graph = batch_convert_soft_map_to_graph(probs_fake)
|
||||
|
||||
fake_scores, _, _ = critic(fake, fake_graph, stage)
|
||||
fake_scores, _, _ = critic(probs_fake, fake_graph, stage)
|
||||
minamo_loss = -torch.mean(fake_scores)
|
||||
ce_loss = F.l1_loss(fake, real)
|
||||
immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage])
|
||||
constraint_loss = outer_border_constraint_loss(fake) + inner_constraint_loss(fake)
|
||||
ce_loss = F.cross_entropy(fake, real)
|
||||
immutable_loss = immutable_penalty_loss(probs_fake, F.softmax(input, dim=1), STAGE_ALLOWED[stage])
|
||||
constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake)
|
||||
|
||||
fake_a, fake_b = fake.chunk(2, dim=0)
|
||||
|
||||
@ -419,12 +413,12 @@ class WGANGinkaLoss:
|
||||
ce_loss * self.weight[1] * (1 - mask_ratio), # 蒙版越大,交叉熵损失权重越小
|
||||
immutable_loss * self.weight[2],
|
||||
constraint_loss * self.weight[3],
|
||||
-js_divergence(fake_a, fake_b) * self.weight[5],
|
||||
-js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
|
||||
]
|
||||
|
||||
if stage == 1:
|
||||
# 第一个阶段检查入口存在性
|
||||
entrance_loss = entrance_constraint_loss(fake)
|
||||
entrance_loss = entrance_constraint_loss(probs_fake)
|
||||
losses.append(entrance_loss * self.weight[4])
|
||||
|
||||
# print(-js_divergence(fake_a, fake_b).item())
|
||||
@ -432,12 +426,36 @@ class WGANGinkaLoss:
|
||||
return sum(losses), minamo_loss, ce_loss, immutable_loss
|
||||
|
||||
def generator_loss_total(self, critic, stage, fake) -> torch.Tensor:
|
||||
fake_graph = batch_convert_soft_map_to_graph(fake)
|
||||
probs_fake = F.softmax(fake, dim=1)
|
||||
fake_graph = batch_convert_soft_map_to_graph(probs_fake)
|
||||
|
||||
fake_scores, _, _ = critic(fake, fake_graph, stage)
|
||||
fake_scores, _, _ = critic(probs_fake, fake_graph, stage)
|
||||
minamo_loss = -torch.mean(fake_scores)
|
||||
immutable_loss = immutable_penalty_loss(fake, fake, STAGE_ALLOWED[stage])
|
||||
constraint_loss = outer_border_constraint_loss(fake) + inner_constraint_loss(fake)
|
||||
constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake)
|
||||
|
||||
fake_a, fake_b = fake.chunk(2, dim=0)
|
||||
|
||||
losses = [
|
||||
minamo_loss * self.weight[0],
|
||||
constraint_loss * self.weight[3],
|
||||
-js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
|
||||
]
|
||||
|
||||
if stage == 1:
|
||||
# 第一个阶段检查入口存在性
|
||||
entrance_loss = entrance_constraint_loss(probs_fake)
|
||||
losses.append(entrance_loss * self.weight[4])
|
||||
|
||||
return sum(losses)
|
||||
|
||||
def generator_loss_total_with_input(self, critic, stage, fake, input) -> torch.Tensor:
|
||||
probs_fake = F.softmax(fake, dim=1)
|
||||
fake_graph = batch_convert_soft_map_to_graph(probs_fake)
|
||||
|
||||
fake_scores, _, _ = critic(probs_fake, fake_graph, stage)
|
||||
minamo_loss = -torch.mean(fake_scores)
|
||||
immutable_loss = immutable_penalty_loss(probs_fake, F.softmax(input, dim=1), STAGE_ALLOWED[stage])
|
||||
constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake)
|
||||
|
||||
fake_a, fake_b = fake.chunk(2, dim=0)
|
||||
|
||||
@ -445,12 +463,12 @@ class WGANGinkaLoss:
|
||||
minamo_loss * self.weight[0],
|
||||
immutable_loss * self.weight[2],
|
||||
constraint_loss * self.weight[3],
|
||||
-js_divergence(fake_a, fake_b) * self.weight[5],
|
||||
-js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
|
||||
]
|
||||
|
||||
if stage == 1:
|
||||
# 第一个阶段检查入口存在性
|
||||
entrance_loss = entrance_constraint_loss(fake)
|
||||
entrance_loss = entrance_constraint_loss(probs_fake)
|
||||
losses.append(entrance_loss * self.weight[4])
|
||||
|
||||
return sum(losses)
|
||||
|
||||
@ -27,7 +27,7 @@ class GinkaModel(nn.Module):
|
||||
x = self.input(x)
|
||||
x = self.unet(x)
|
||||
x = self.output(x, stage)
|
||||
return F.softmax(x, dim=1)
|
||||
return x
|
||||
|
||||
# 检查显存占用
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -201,9 +201,12 @@ def train():
|
||||
elif train_stage == 3 or train_stage == 4:
|
||||
fake1, fake2, fake3 = gen_total(ginka, masked1, True, False)
|
||||
|
||||
loss_g1 = criterion.generator_loss_total(minamo, 1, fake1)
|
||||
loss_g2 = criterion.generator_loss_total(minamo, 2, fake2)
|
||||
loss_g3 = criterion.generator_loss_total(minamo, 3, fake3)
|
||||
if train_stage == 3:
|
||||
loss_g1 = criterion.generator_loss_total_with_input(minamo, 1, fake1, masked1)
|
||||
else:
|
||||
loss_g1 = criterion.generator_loss_total(minamo, 1, fake1)
|
||||
loss_g2 = criterion.generator_loss_total_with_input(minamo, 2, fake2, fake1)
|
||||
loss_g3 = criterion.generator_loss_total_with_input(minamo, 3, fake3, fake2)
|
||||
|
||||
loss_g = (loss_g1 + loss_g2 + loss_g3) / 3.0
|
||||
loss_g.backward()
|
||||
@ -221,7 +224,7 @@ def train():
|
||||
f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f} | R: {random_ratio:.1f}"
|
||||
)
|
||||
|
||||
if avg_loss_ce < 0.1:
|
||||
if avg_loss_ce < 0.5:
|
||||
low_loss_epochs += 1
|
||||
else:
|
||||
low_loss_epochs = 0
|
||||
|
||||
Loading…
Reference in New Issue
Block a user