fix: 修复损失值计算

This commit is contained in:
unanmed 2025-04-16 21:54:53 +08:00
parent 9b5abf177c
commit ef0b7ffba2
3 changed files with 52 additions and 31 deletions

View File

@ -295,7 +295,10 @@ def interpolate_graph_features(real_graph, fake_graph, epsilon=0.5):
return Data(x=x_interp, edge_index=edge_index_interp, edge_attr=edge_attr_interp)
def js_divergence(p, q, eps=1e-8):
def js_divergence(p, q, eps=1e-6, softmax=False):
if softmax:
p = F.softmax(p, dim=1)
q = F.softmax(q, dim=1)
# softmax 后变成概率分布
m = 0.5 * (p + q)
@ -304,8 +307,6 @@ def js_divergence(p, q, eps=1e-8):
log_q = torch.log(q + eps)
log_m = torch.log(m + eps)
nn.KLDivLoss
kl_pm = F.kl_div(log_p, log_m, reduction='batchmean', log_target=True) # KL(p || m)
kl_qm = F.kl_div(log_q, log_m, reduction='batchmean', log_target=True) # KL(q || m)
@ -335,7 +336,7 @@ def immutable_penalty_loss(
return penalty
class WGANGinkaLoss:
def __init__(self, lambda_gp=100, weight=[1, 0.4, 25, 0.2, 0.2, 0.02]):
def __init__(self, lambda_gp=100, weight=[1, 0.4, 25, 0.2, 0.2, 0.01]):
# weight: 判别器损失,L1 损失,不可修改类型损失,图块类型损失,入口存在性损失,多样性损失
self.lambda_gp = lambda_gp # 梯度惩罚系数
self.weight = weight
@ -380,6 +381,7 @@ class WGANGinkaLoss:
self, critic, stage: int, real_data: torch.Tensor, fake_data: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
""" 判别器损失函数 """
fake_data = F.softmax(fake_data, dim=1)
real_graph = batch_convert_soft_map_to_graph(real_data)
fake_graph = batch_convert_soft_map_to_graph(fake_data)
real_scores, _, _ = critic(real_data, real_graph, stage)
@ -393,24 +395,16 @@ class WGANGinkaLoss:
return total_loss, d_loss
def calculate_similarity_one(self, map1, map2):
topo1 = build_topological_graph(map1)
topo2 = build_topological_graph(map2)
vis_sim = calculate_visual_similarity(map1, map2)
topo_sim = overall_similarity(topo1, topo2)
return vis_sim, topo_sim
def generator_loss(self, critic, stage, mask_ratio, real, fake: torch.Tensor, input) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
""" 生成器损失函数 """
fake_graph = batch_convert_soft_map_to_graph(fake)
probs_fake = F.softmax(fake, dim=1)
fake_graph = batch_convert_soft_map_to_graph(probs_fake)
fake_scores, _, _ = critic(fake, fake_graph, stage)
fake_scores, _, _ = critic(probs_fake, fake_graph, stage)
minamo_loss = -torch.mean(fake_scores)
ce_loss = F.l1_loss(fake, real)
immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage])
constraint_loss = outer_border_constraint_loss(fake) + inner_constraint_loss(fake)
ce_loss = F.cross_entropy(fake, real)
immutable_loss = immutable_penalty_loss(probs_fake, F.softmax(input, dim=1), STAGE_ALLOWED[stage])
constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake)
fake_a, fake_b = fake.chunk(2, dim=0)
@ -419,12 +413,12 @@ class WGANGinkaLoss:
ce_loss * self.weight[1] * (1 - mask_ratio), # 蒙版越大,交叉熵损失权重越小
immutable_loss * self.weight[2],
constraint_loss * self.weight[3],
-js_divergence(fake_a, fake_b) * self.weight[5],
-js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
]
if stage == 1:
# 第一个阶段检查入口存在性
entrance_loss = entrance_constraint_loss(fake)
entrance_loss = entrance_constraint_loss(probs_fake)
losses.append(entrance_loss * self.weight[4])
# print(-js_divergence(fake_a, fake_b).item())
@ -432,12 +426,36 @@ class WGANGinkaLoss:
return sum(losses), minamo_loss, ce_loss, immutable_loss
def generator_loss_total(self, critic, stage, fake) -> torch.Tensor:
fake_graph = batch_convert_soft_map_to_graph(fake)
probs_fake = F.softmax(fake, dim=1)
fake_graph = batch_convert_soft_map_to_graph(probs_fake)
fake_scores, _, _ = critic(fake, fake_graph, stage)
fake_scores, _, _ = critic(probs_fake, fake_graph, stage)
minamo_loss = -torch.mean(fake_scores)
immutable_loss = immutable_penalty_loss(fake, fake, STAGE_ALLOWED[stage])
constraint_loss = outer_border_constraint_loss(fake) + inner_constraint_loss(fake)
constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake)
fake_a, fake_b = fake.chunk(2, dim=0)
losses = [
minamo_loss * self.weight[0],
constraint_loss * self.weight[3],
-js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
]
if stage == 1:
# 第一个阶段检查入口存在性
entrance_loss = entrance_constraint_loss(probs_fake)
losses.append(entrance_loss * self.weight[4])
return sum(losses)
def generator_loss_total_with_input(self, critic, stage, fake, input) -> torch.Tensor:
probs_fake = F.softmax(fake, dim=1)
fake_graph = batch_convert_soft_map_to_graph(probs_fake)
fake_scores, _, _ = critic(probs_fake, fake_graph, stage)
minamo_loss = -torch.mean(fake_scores)
immutable_loss = immutable_penalty_loss(probs_fake, F.softmax(input, dim=1), STAGE_ALLOWED[stage])
constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake)
fake_a, fake_b = fake.chunk(2, dim=0)
@ -445,12 +463,12 @@ class WGANGinkaLoss:
minamo_loss * self.weight[0],
immutable_loss * self.weight[2],
constraint_loss * self.weight[3],
-js_divergence(fake_a, fake_b) * self.weight[5],
-js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
]
if stage == 1:
# 第一个阶段检查入口存在性
entrance_loss = entrance_constraint_loss(fake)
entrance_loss = entrance_constraint_loss(probs_fake)
losses.append(entrance_loss * self.weight[4])
return sum(losses)

View File

@ -27,7 +27,7 @@ class GinkaModel(nn.Module):
x = self.input(x)
x = self.unet(x)
x = self.output(x, stage)
return F.softmax(x, dim=1)
return x
# 检查显存占用
if __name__ == "__main__":

View File

@ -201,9 +201,12 @@ def train():
elif train_stage == 3 or train_stage == 4:
fake1, fake2, fake3 = gen_total(ginka, masked1, True, False)
loss_g1 = criterion.generator_loss_total(minamo, 1, fake1)
loss_g2 = criterion.generator_loss_total(minamo, 2, fake2)
loss_g3 = criterion.generator_loss_total(minamo, 3, fake3)
if train_stage == 3:
loss_g1 = criterion.generator_loss_total_with_input(minamo, 1, fake1, masked1)
else:
loss_g1 = criterion.generator_loss_total(minamo, 1, fake1)
loss_g2 = criterion.generator_loss_total_with_input(minamo, 2, fake2, fake1)
loss_g3 = criterion.generator_loss_total_with_input(minamo, 3, fake3, fake2)
loss_g = (loss_g1 + loss_g2 + loss_g3) / 3.0
loss_g.backward()
@ -221,7 +224,7 @@ def train():
f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f} | R: {random_ratio:.1f}"
)
if avg_loss_ce < 0.1:
if avg_loss_ce < 0.5:
low_loss_epochs += 1
else:
low_loss_epochs = 0