diff --git a/ginka/model/loss.py b/ginka/model/loss.py
index a154f9c..b2a0431 100644
--- a/ginka/model/loss.py
+++ b/ginka/model/loss.py
@@ -295,7 +295,10 @@ def interpolate_graph_features(real_graph, fake_graph, epsilon=0.5):
 
     return Data(x=x_interp, edge_index=edge_index_interp, edge_attr=edge_attr_interp)
 
-def js_divergence(p, q, eps=1e-8):
+def js_divergence(p, q, eps=1e-6, softmax=False):
+    if softmax:
+        p = F.softmax(p, dim=1)
+        q = F.softmax(q, dim=1)
     # after softmax, p and q are probability distributions
     m = 0.5 * (p + q)
 
@@ -303,8 +306,6 @@ def js_divergence(p, q, eps=1e-8):
     log_p = torch.log(p + eps)
     log_q = torch.log(q + eps)
     log_m = torch.log(m + eps)
-
-    nn.KLDivLoss
     kl_pm = F.kl_div(log_p, log_m, reduction='batchmean', log_target=True)  # KL(p || m)
     kl_qm = F.kl_div(log_q, log_m, reduction='batchmean', log_target=True)  # KL(q || m)
 
@@ -335,7 +336,7 @@ def immutable_penalty_loss(
     return penalty
 
 class WGANGinkaLoss:
-    def __init__(self, lambda_gp=100, weight=[1, 0.4, 25, 0.2, 0.2, 0.02]):
+    def __init__(self, lambda_gp=100, weight=[1, 0.4, 25, 0.2, 0.2, 0.01]):
         # weight: critic loss, L1 loss, immutable-tile loss, tile-type loss, entrance-presence loss, diversity loss
         self.lambda_gp = lambda_gp  # gradient penalty coefficient
         self.weight = weight
@@ -380,6 +381,7 @@ class WGANGinkaLoss:
         self, critic, stage: int, real_data: torch.Tensor, fake_data: torch.Tensor
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """ Critic (discriminator) loss. """
+        fake_data = F.softmax(fake_data, dim=1)
         real_graph = batch_convert_soft_map_to_graph(real_data)
         fake_graph = batch_convert_soft_map_to_graph(fake_data)
         real_scores, _, _ = critic(real_data, real_graph, stage)
@@ -392,25 +394,17 @@ class WGANGinkaLoss:
         total_loss = d_loss + self.lambda_gp * grad_loss
 
         return total_loss, d_loss
-
-    def calculate_similarity_one(self, map1, map2):
-        topo1 = build_topological_graph(map1)
-        topo2 = build_topological_graph(map2)
-
-        vis_sim = calculate_visual_similarity(map1, map2)
-        topo_sim = overall_similarity(topo1, topo2)
-
-        return vis_sim, topo_sim
 
     def generator_loss(self, critic, stage, mask_ratio, real, fake: torch.Tensor, input) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         """ Generator loss. """
-        fake_graph = batch_convert_soft_map_to_graph(fake)
+        probs_fake = F.softmax(fake, dim=1)
+        fake_graph = batch_convert_soft_map_to_graph(probs_fake)
 
-        fake_scores, _, _ = critic(fake, fake_graph, stage)
+        fake_scores, _, _ = critic(probs_fake, fake_graph, stage)
         minamo_loss = -torch.mean(fake_scores)
-        ce_loss = F.l1_loss(fake, real)
-        immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage])
-        constraint_loss = outer_border_constraint_loss(fake) + inner_constraint_loss(fake)
+        ce_loss = F.cross_entropy(fake, real)
+        immutable_loss = immutable_penalty_loss(probs_fake, F.softmax(input, dim=1), STAGE_ALLOWED[stage])
+        constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake)
 
         fake_a, fake_b = fake.chunk(2, dim=0)
 
@@ -419,12 +413,12 @@ class WGANGinkaLoss:
             ce_loss * self.weight[1] * (1 - mask_ratio),  # the larger the mask ratio, the lower the cross-entropy weight
             immutable_loss * self.weight[2],
             constraint_loss * self.weight[3],
-            -js_divergence(fake_a, fake_b) * self.weight[5],
+            -js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
         ]
 
         if stage == 1:
             # stage 1 checks entrance presence
-            entrance_loss = entrance_constraint_loss(fake)
+            entrance_loss = entrance_constraint_loss(probs_fake)
             losses.append(entrance_loss * self.weight[4])
 
         # print(-js_divergence(fake_a, fake_b).item())
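# --- Aside: a minimal sanity check for the new softmax=True path of
# js_divergence, runnable on its own (it assumes only torch plus this repo's
# ginka.model.loss; the batch/channel/map sizes are hypothetical). One caveat
# worth recording: with log_target=True, F.kl_div(log_p, log_m) evaluates
# KL(m || p), not the KL(p || m) stated in the inline comments above, so the
# value is a symmetrized KL rather than textbook JS divergence. It is still
# non-negative and zero exactly when p == q, which is all the diversity term
# relies on.
import torch
from ginka.model.loss import js_divergence

logits_a = torch.randn(4, 32, 13, 13)  # hypothetical [B, C, H, W] map logits
logits_b = torch.randn(4, 32, 13, 13)

assert js_divergence(logits_a, logits_a, softmax=True).item() < 1e-6  # identical inputs -> ~0
assert js_divergence(logits_a, logits_b, softmax=True).item() > 0.0   # distinct inputs -> positive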
@@ -432,12 +426,36 @@ class WGANGinkaLoss:
         return sum(losses), minamo_loss, ce_loss, immutable_loss
 
     def generator_loss_total(self, critic, stage, fake) -> torch.Tensor:
-        fake_graph = batch_convert_soft_map_to_graph(fake)
+        probs_fake = F.softmax(fake, dim=1)
+        fake_graph = batch_convert_soft_map_to_graph(probs_fake)
 
-        fake_scores, _, _ = critic(fake, fake_graph, stage)
+        fake_scores, _, _ = critic(probs_fake, fake_graph, stage)
         minamo_loss = -torch.mean(fake_scores)
-        immutable_loss = immutable_penalty_loss(fake, fake, STAGE_ALLOWED[stage])
-        constraint_loss = outer_border_constraint_loss(fake) + inner_constraint_loss(fake)
+        constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake)
+
+        fake_a, fake_b = fake.chunk(2, dim=0)
+
+        losses = [
+            minamo_loss * self.weight[0],
+            constraint_loss * self.weight[3],
+            -js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
+        ]
+
+        if stage == 1:
+            # stage 1 checks entrance presence
+            entrance_loss = entrance_constraint_loss(probs_fake)
+            losses.append(entrance_loss * self.weight[4])
+
+        return sum(losses)
+
+    def generator_loss_total_with_input(self, critic, stage, fake, input) -> torch.Tensor:
+        probs_fake = F.softmax(fake, dim=1)
+        fake_graph = batch_convert_soft_map_to_graph(probs_fake)
+
+        fake_scores, _, _ = critic(probs_fake, fake_graph, stage)
+        minamo_loss = -torch.mean(fake_scores)
+        immutable_loss = immutable_penalty_loss(probs_fake, F.softmax(input, dim=1), STAGE_ALLOWED[stage])
+        constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake)
 
         fake_a, fake_b = fake.chunk(2, dim=0)
 
@@ -445,12 +463,12 @@ class WGANGinkaLoss:
             minamo_loss * self.weight[0],
             immutable_loss * self.weight[2],
             constraint_loss * self.weight[3],
-            -js_divergence(fake_a, fake_b) * self.weight[5],
+            -js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
         ]
 
         if stage == 1:
             # stage 1 checks entrance presence
-            entrance_loss = entrance_constraint_loss(fake)
+            entrance_loss = entrance_constraint_loss(probs_fake)
             losses.append(entrance_loss * self.weight[4])
 
         return sum(losses)
diff --git a/ginka/model/model.py b/ginka/model/model.py
index 9079cb5..260d2fe 100644
--- a/ginka/model/model.py
+++ b/ginka/model/model.py
@@ -27,7 +27,7 @@ class GinkaModel(nn.Module):
         x = self.input(x)
         x = self.unet(x)
         x = self.output(x, stage)
-        return F.softmax(x, dim=1)
+        return x
 
 # check GPU memory usage
 if __name__ == "__main__":
diff --git a/ginka/train_wgan.py b/ginka/train_wgan.py
index 8f0aebf..2392f34 100644
--- a/ginka/train_wgan.py
+++ b/ginka/train_wgan.py
@@ -201,9 +201,12 @@ def train():
             elif train_stage == 3 or train_stage == 4:
                 fake1, fake2, fake3 = gen_total(ginka, masked1, True, False)
 
-                loss_g1 = criterion.generator_loss_total(minamo, 1, fake1)
-                loss_g2 = criterion.generator_loss_total(minamo, 2, fake2)
-                loss_g3 = criterion.generator_loss_total(minamo, 3, fake3)
+                if train_stage == 3:
+                    loss_g1 = criterion.generator_loss_total_with_input(minamo, 1, fake1, masked1)
+                else:
+                    loss_g1 = criterion.generator_loss_total(minamo, 1, fake1)
+                loss_g2 = criterion.generator_loss_total_with_input(minamo, 2, fake2, fake1)
+                loss_g3 = criterion.generator_loss_total_with_input(minamo, 3, fake3, fake2)
 
                 loss_g = (loss_g1 + loss_g2 + loss_g3) / 3.0
                 loss_g.backward()
@@ -221,7 +224,7 @@ def train():
                     f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f} | R: {random_ratio:.1f}"
                 )
 
-                if avg_loss_ce < 0.1:
+                if avg_loss_ce < 0.5:
                     low_loss_epochs += 1
                 else:
                     low_loss_epochs = 0
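# --- Aside on the reconstruction term: F.l1_loss(fake, real) became
# F.cross_entropy(fake, real), fed raw logits deliberately (cross_entropy
# applies log_softmax internally, which is also why model.py now returns
# logits). Since PyTorch 1.10, cross_entropy accepts either class indices or
# per-class probabilities as the target, so a soft/one-hot `real` map works
# unchanged. Note the scale difference versus L1: a uniform guess over C
# classes already scores ln C (ln 32 ≈ 3.47), which presumably motivates
# raising the avg_loss_ce threshold from 0.1 to 0.5 above. A self-contained
# check of the two target formats (all shapes hypothetical):
import torch
import torch.nn.functional as F

fake = torch.randn(2, 32, 13, 13)                        # raw generator logits
hard = torch.randint(0, 32, (2, 13, 13))                 # class-index target
soft = F.one_hot(hard, 32).permute(0, 3, 1, 2).float()   # probability target

assert torch.allclose(F.cross_entropy(fake, soft), F.cross_entropy(fake, hard))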
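# --- Usage note: GinkaModel.forward now returns raw logits, so any consumer
# that needs tile probabilities or a hard map must normalize explicitly (the
# losses do this via probs_fake = F.softmax(fake, dim=1)). A minimal sketch;
# the tensor below is a stand-in for whatever the model actually outputs:
import torch
import torch.nn.functional as F

logits = torch.randn(4, 32, 13, 13)  # stand-in for the model's raw output
probs = F.softmax(logits, dim=1)     # per-cell distribution over tile types
tile_ids = probs.argmax(dim=1)       # hard [B, H, W] tile map

assert torch.allclose(probs.sum(dim=1), torch.ones(4, 13, 13))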