mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-21 02:11:13 +08:00
fix: 修复损失值计算
This commit is contained in:
parent
9b5abf177c
commit
ef0b7ffba2
@ -295,7 +295,10 @@ def interpolate_graph_features(real_graph, fake_graph, epsilon=0.5):
|
|||||||
|
|
||||||
return Data(x=x_interp, edge_index=edge_index_interp, edge_attr=edge_attr_interp)
|
return Data(x=x_interp, edge_index=edge_index_interp, edge_attr=edge_attr_interp)
|
||||||
|
|
||||||
def js_divergence(p, q, eps=1e-8):
|
def js_divergence(p, q, eps=1e-6, softmax=False):
|
||||||
|
if softmax:
|
||||||
|
p = F.softmax(p, dim=1)
|
||||||
|
q = F.softmax(q, dim=1)
|
||||||
# softmax 后变成概率分布
|
# softmax 后变成概率分布
|
||||||
m = 0.5 * (p + q)
|
m = 0.5 * (p + q)
|
||||||
|
|
||||||
@ -304,8 +307,6 @@ def js_divergence(p, q, eps=1e-8):
|
|||||||
log_q = torch.log(q + eps)
|
log_q = torch.log(q + eps)
|
||||||
log_m = torch.log(m + eps)
|
log_m = torch.log(m + eps)
|
||||||
|
|
||||||
nn.KLDivLoss
|
|
||||||
|
|
||||||
kl_pm = F.kl_div(log_p, log_m, reduction='batchmean', log_target=True) # KL(p || m)
|
kl_pm = F.kl_div(log_p, log_m, reduction='batchmean', log_target=True) # KL(p || m)
|
||||||
kl_qm = F.kl_div(log_q, log_m, reduction='batchmean', log_target=True) # KL(q || m)
|
kl_qm = F.kl_div(log_q, log_m, reduction='batchmean', log_target=True) # KL(q || m)
|
||||||
|
|
||||||
@ -335,7 +336,7 @@ def immutable_penalty_loss(
|
|||||||
return penalty
|
return penalty
|
||||||
|
|
||||||
class WGANGinkaLoss:
|
class WGANGinkaLoss:
|
||||||
def __init__(self, lambda_gp=100, weight=[1, 0.4, 25, 0.2, 0.2, 0.02]):
|
def __init__(self, lambda_gp=100, weight=[1, 0.4, 25, 0.2, 0.2, 0.01]):
|
||||||
# weight: 判别器损失,L1 损失,不可修改类型损失,图块类型损失,入口存在性损失,多样性损失
|
# weight: 判别器损失,L1 损失,不可修改类型损失,图块类型损失,入口存在性损失,多样性损失
|
||||||
self.lambda_gp = lambda_gp # 梯度惩罚系数
|
self.lambda_gp = lambda_gp # 梯度惩罚系数
|
||||||
self.weight = weight
|
self.weight = weight
|
||||||
@ -380,6 +381,7 @@ class WGANGinkaLoss:
|
|||||||
self, critic, stage: int, real_data: torch.Tensor, fake_data: torch.Tensor
|
self, critic, stage: int, real_data: torch.Tensor, fake_data: torch.Tensor
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
""" 判别器损失函数 """
|
""" 判别器损失函数 """
|
||||||
|
fake_data = F.softmax(fake_data, dim=1)
|
||||||
real_graph = batch_convert_soft_map_to_graph(real_data)
|
real_graph = batch_convert_soft_map_to_graph(real_data)
|
||||||
fake_graph = batch_convert_soft_map_to_graph(fake_data)
|
fake_graph = batch_convert_soft_map_to_graph(fake_data)
|
||||||
real_scores, _, _ = critic(real_data, real_graph, stage)
|
real_scores, _, _ = critic(real_data, real_graph, stage)
|
||||||
@ -393,24 +395,16 @@ class WGANGinkaLoss:
|
|||||||
|
|
||||||
return total_loss, d_loss
|
return total_loss, d_loss
|
||||||
|
|
||||||
def calculate_similarity_one(self, map1, map2):
|
|
||||||
topo1 = build_topological_graph(map1)
|
|
||||||
topo2 = build_topological_graph(map2)
|
|
||||||
|
|
||||||
vis_sim = calculate_visual_similarity(map1, map2)
|
|
||||||
topo_sim = overall_similarity(topo1, topo2)
|
|
||||||
|
|
||||||
return vis_sim, topo_sim
|
|
||||||
|
|
||||||
def generator_loss(self, critic, stage, mask_ratio, real, fake: torch.Tensor, input) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
def generator_loss(self, critic, stage, mask_ratio, real, fake: torch.Tensor, input) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
""" 生成器损失函数 """
|
""" 生成器损失函数 """
|
||||||
fake_graph = batch_convert_soft_map_to_graph(fake)
|
probs_fake = F.softmax(fake, dim=1)
|
||||||
|
fake_graph = batch_convert_soft_map_to_graph(probs_fake)
|
||||||
|
|
||||||
fake_scores, _, _ = critic(fake, fake_graph, stage)
|
fake_scores, _, _ = critic(probs_fake, fake_graph, stage)
|
||||||
minamo_loss = -torch.mean(fake_scores)
|
minamo_loss = -torch.mean(fake_scores)
|
||||||
ce_loss = F.l1_loss(fake, real)
|
ce_loss = F.cross_entropy(fake, real)
|
||||||
immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage])
|
immutable_loss = immutable_penalty_loss(probs_fake, F.softmax(input, dim=1), STAGE_ALLOWED[stage])
|
||||||
constraint_loss = outer_border_constraint_loss(fake) + inner_constraint_loss(fake)
|
constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake)
|
||||||
|
|
||||||
fake_a, fake_b = fake.chunk(2, dim=0)
|
fake_a, fake_b = fake.chunk(2, dim=0)
|
||||||
|
|
||||||
@ -419,12 +413,12 @@ class WGANGinkaLoss:
|
|||||||
ce_loss * self.weight[1] * (1 - mask_ratio), # 蒙版越大,交叉熵损失权重越小
|
ce_loss * self.weight[1] * (1 - mask_ratio), # 蒙版越大,交叉熵损失权重越小
|
||||||
immutable_loss * self.weight[2],
|
immutable_loss * self.weight[2],
|
||||||
constraint_loss * self.weight[3],
|
constraint_loss * self.weight[3],
|
||||||
-js_divergence(fake_a, fake_b) * self.weight[5],
|
-js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
|
||||||
]
|
]
|
||||||
|
|
||||||
if stage == 1:
|
if stage == 1:
|
||||||
# 第一个阶段检查入口存在性
|
# 第一个阶段检查入口存在性
|
||||||
entrance_loss = entrance_constraint_loss(fake)
|
entrance_loss = entrance_constraint_loss(probs_fake)
|
||||||
losses.append(entrance_loss * self.weight[4])
|
losses.append(entrance_loss * self.weight[4])
|
||||||
|
|
||||||
# print(-js_divergence(fake_a, fake_b).item())
|
# print(-js_divergence(fake_a, fake_b).item())
|
||||||
@ -432,12 +426,36 @@ class WGANGinkaLoss:
|
|||||||
return sum(losses), minamo_loss, ce_loss, immutable_loss
|
return sum(losses), minamo_loss, ce_loss, immutable_loss
|
||||||
|
|
||||||
def generator_loss_total(self, critic, stage, fake) -> torch.Tensor:
|
def generator_loss_total(self, critic, stage, fake) -> torch.Tensor:
|
||||||
fake_graph = batch_convert_soft_map_to_graph(fake)
|
probs_fake = F.softmax(fake, dim=1)
|
||||||
|
fake_graph = batch_convert_soft_map_to_graph(probs_fake)
|
||||||
|
|
||||||
fake_scores, _, _ = critic(fake, fake_graph, stage)
|
fake_scores, _, _ = critic(probs_fake, fake_graph, stage)
|
||||||
minamo_loss = -torch.mean(fake_scores)
|
minamo_loss = -torch.mean(fake_scores)
|
||||||
immutable_loss = immutable_penalty_loss(fake, fake, STAGE_ALLOWED[stage])
|
constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake)
|
||||||
constraint_loss = outer_border_constraint_loss(fake) + inner_constraint_loss(fake)
|
|
||||||
|
fake_a, fake_b = fake.chunk(2, dim=0)
|
||||||
|
|
||||||
|
losses = [
|
||||||
|
minamo_loss * self.weight[0],
|
||||||
|
constraint_loss * self.weight[3],
|
||||||
|
-js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
|
||||||
|
]
|
||||||
|
|
||||||
|
if stage == 1:
|
||||||
|
# 第一个阶段检查入口存在性
|
||||||
|
entrance_loss = entrance_constraint_loss(probs_fake)
|
||||||
|
losses.append(entrance_loss * self.weight[4])
|
||||||
|
|
||||||
|
return sum(losses)
|
||||||
|
|
||||||
|
def generator_loss_total_with_input(self, critic, stage, fake, input) -> torch.Tensor:
|
||||||
|
probs_fake = F.softmax(fake, dim=1)
|
||||||
|
fake_graph = batch_convert_soft_map_to_graph(probs_fake)
|
||||||
|
|
||||||
|
fake_scores, _, _ = critic(probs_fake, fake_graph, stage)
|
||||||
|
minamo_loss = -torch.mean(fake_scores)
|
||||||
|
immutable_loss = immutable_penalty_loss(probs_fake, F.softmax(input, dim=1), STAGE_ALLOWED[stage])
|
||||||
|
constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake)
|
||||||
|
|
||||||
fake_a, fake_b = fake.chunk(2, dim=0)
|
fake_a, fake_b = fake.chunk(2, dim=0)
|
||||||
|
|
||||||
@ -445,12 +463,12 @@ class WGANGinkaLoss:
|
|||||||
minamo_loss * self.weight[0],
|
minamo_loss * self.weight[0],
|
||||||
immutable_loss * self.weight[2],
|
immutable_loss * self.weight[2],
|
||||||
constraint_loss * self.weight[3],
|
constraint_loss * self.weight[3],
|
||||||
-js_divergence(fake_a, fake_b) * self.weight[5],
|
-js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
|
||||||
]
|
]
|
||||||
|
|
||||||
if stage == 1:
|
if stage == 1:
|
||||||
# 第一个阶段检查入口存在性
|
# 第一个阶段检查入口存在性
|
||||||
entrance_loss = entrance_constraint_loss(fake)
|
entrance_loss = entrance_constraint_loss(probs_fake)
|
||||||
losses.append(entrance_loss * self.weight[4])
|
losses.append(entrance_loss * self.weight[4])
|
||||||
|
|
||||||
return sum(losses)
|
return sum(losses)
|
||||||
|
|||||||
@ -27,7 +27,7 @@ class GinkaModel(nn.Module):
|
|||||||
x = self.input(x)
|
x = self.input(x)
|
||||||
x = self.unet(x)
|
x = self.unet(x)
|
||||||
x = self.output(x, stage)
|
x = self.output(x, stage)
|
||||||
return F.softmax(x, dim=1)
|
return x
|
||||||
|
|
||||||
# 检查显存占用
|
# 检查显存占用
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@ -201,9 +201,12 @@ def train():
|
|||||||
elif train_stage == 3 or train_stage == 4:
|
elif train_stage == 3 or train_stage == 4:
|
||||||
fake1, fake2, fake3 = gen_total(ginka, masked1, True, False)
|
fake1, fake2, fake3 = gen_total(ginka, masked1, True, False)
|
||||||
|
|
||||||
loss_g1 = criterion.generator_loss_total(minamo, 1, fake1)
|
if train_stage == 3:
|
||||||
loss_g2 = criterion.generator_loss_total(minamo, 2, fake2)
|
loss_g1 = criterion.generator_loss_total_with_input(minamo, 1, fake1, masked1)
|
||||||
loss_g3 = criterion.generator_loss_total(minamo, 3, fake3)
|
else:
|
||||||
|
loss_g1 = criterion.generator_loss_total(minamo, 1, fake1)
|
||||||
|
loss_g2 = criterion.generator_loss_total_with_input(minamo, 2, fake2, fake1)
|
||||||
|
loss_g3 = criterion.generator_loss_total_with_input(minamo, 3, fake3, fake2)
|
||||||
|
|
||||||
loss_g = (loss_g1 + loss_g2 + loss_g3) / 3.0
|
loss_g = (loss_g1 + loss_g2 + loss_g3) / 3.0
|
||||||
loss_g.backward()
|
loss_g.backward()
|
||||||
@ -221,7 +224,7 @@ def train():
|
|||||||
f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f} | R: {random_ratio:.1f}"
|
f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f} | R: {random_ratio:.1f}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if avg_loss_ce < 0.1:
|
if avg_loss_ce < 0.5:
|
||||||
low_loss_epochs += 1
|
low_loss_epochs += 1
|
||||||
else:
|
else:
|
||||||
low_loss_epochs = 0
|
low_loss_epochs = 0
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user