From 7b138c66d9e7a1a79a1f2bf728903ef71c037042 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Thu, 8 May 2025 18:42:15 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20=E5=8A=A0=E5=BC=BA=20GCN=20=E9=83=A8?= =?UTF-8?q?=E5=88=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/common/common.py | 8 ++++++-- ginka/generator/input.py | 10 ++++++---- ginka/generator/unet.py | 5 ++++- ginka/train_wgan.py | 4 ++-- shared/graph.py | 10 ++-------- 5 files changed, 20 insertions(+), 17 deletions(-) diff --git a/ginka/common/common.py b/ginka/common/common.py index 873a289..681ec46 100644 --- a/ginka/common/common.py +++ b/ginka/common/common.py @@ -25,9 +25,11 @@ class GCNBlock(nn.Module): def __init__(self, in_ch, hidden_ch, out_ch, w, h): super().__init__() self.conv1 = GCNConv(in_ch, hidden_ch) - self.conv2 = GCNConv(hidden_ch, out_ch) + self.conv2 = GCNConv(hidden_ch, hidden_ch) + self.conv3 = GCNConv(hidden_ch, out_ch) self.norm1 = nn.LayerNorm(hidden_ch) - self.norm2 = nn.LayerNorm(out_ch) + self.norm2 = nn.LayerNorm(hidden_ch) + self.norm3 = nn.LayerNorm(out_ch) self.single_edge_index, _ = grid(h, w) # [2, E] for a single map def forward(self, x): @@ -49,6 +51,8 @@ class GCNBlock(nn.Module): x = F.elu(self.norm1(x)) x = self.conv2(x, edge_index) x = F.elu(self.norm2(x)) + x = self.conv3(x, edge_index) + x = F.elu(self.norm3(x)) # Reshape back to [B, C, H, W] x = x.view(B, H, W, -1).permute(0, 3, 1, 2) diff --git a/ginka/generator/input.py b/ginka/generator/input.py index 3275bc6..e8339ac 100644 --- a/ginka/generator/input.py +++ b/ginka/generator/input.py @@ -27,15 +27,15 @@ class InputUpsample(nn.Module): def __init__(self, in_ch, hidden_ch=64, out_ch=64): super().__init__() self.net = nn.Sequential( - nn.Conv2d(in_ch, hidden_ch, kernel_size=3, padding=1), + ConvFusionModule(in_ch, hidden_ch, hidden_ch, 13, 13), nn.ELU(), nn.Upsample(scale_factor=2, mode='nearest'), # 13x13 → 26x26 - nn.Conv2d(hidden_ch, hidden_ch, kernel_size=3, padding=1), + ConvFusionModule(hidden_ch, hidden_ch, hidden_ch, 26, 26), nn.ELU(), nn.Upsample(size=(32, 32), mode='nearest'), # 26x26 → 32x32 - nn.Conv2d(hidden_ch, out_ch, kernel_size=3, padding=1), + ConvFusionModule(hidden_ch, hidden_ch, out_ch, 32, 32), nn.ELU(), ) @@ -52,11 +52,13 @@ class GinkaInput(nn.Module): self.enc2 = ConvFusionModule(out_ch, out_ch*4, out_ch, out_size[0], out_size[1]) self.inject1 = ConditionInjector(256, in_ch) self.inject2 = ConditionInjector(256, out_ch) + self.inject3 = ConditionInjector(256, out_ch) def forward(self, x, cond): x = self.enc1(x) x = self.inject1(x, cond) x = self.upsample(x) - x = self.enc2(x) x = self.inject2(x, cond) + x = self.enc2(x) + x = self.inject3(x, cond) return x diff --git a/ginka/generator/unet.py b/ginka/generator/unet.py index 5fd2647..a769df5 100644 --- a/ginka/generator/unet.py +++ b/ginka/generator/unet.py @@ -158,10 +158,13 @@ class GinkaBottleneck(nn.Module): super().__init__() self.transformer = GinkaTransformerEncoder( in_dim=module_ch*w*h, hidden_dim=module_ch*w*h, out_dim=module_ch*w*h, - token_size=16, ff_dim=1024, num_layers=6 + token_size=16, ff_dim=1024, num_layers=4 ) self.gcn = GCNBlock(module_ch, module_ch*2, module_ch, 4, 4) self.fusion = nn.Conv2d(module_ch*3, module_ch, 1) + # self.conv = ConvBlock(module_ch, module_ch) + # self.gcn = GCNBlock(module_ch, module_ch*2, module_ch, w, h) + # self.fusion = FusionModule(module_ch*2, module_ch) self.inject = ConditionInjector(256, module_ch) def forward(self, x, cond): diff --git a/ginka/train_wgan.py b/ginka/train_wgan.py index 2a6e375..dd4b54b 100644 --- a/ginka/train_wgan.py +++ b/ginka/train_wgan.py @@ -350,8 +350,8 @@ def train(): else: g_steps = 1 - if avg_loss_ginka > 0: - g_steps += int(max(avg_loss_ginka * 5, 0)) + if avg_loss_ginka > 0 and epoch > 20 and not args.resume: + g_steps += int(min(avg_loss_ginka * 5, 50)) if avg_loss_minamo > 0: c_steps = int(min(5 + avg_loss_minamo * 5, 15)) diff --git a/shared/graph.py b/shared/graph.py index e82ac7c..c109242 100644 --- a/shared/graph.py +++ b/shared/graph.py @@ -35,16 +35,10 @@ def differentiable_convert_to_data(map_probs: torch.Tensor) -> Data: torch.stack([edge_dst, edge_src], dim=0) # 反向连接 ], dim=1).to(device, dtype=torch.long) - # 3. 计算可导的边权重 - wall_class_idx = 1 # 假设类别 1 是墙 - src_probs = torch.sigmoid(-map_probs[wall_class_idx].flatten()[edge_src]) - dst_probs = torch.sigmoid(-map_probs[wall_class_idx].flatten()[edge_dst]) - edge_mask = torch.nn.functional.softplus(src_probs * dst_probs).unsqueeze(1) # [E, 1] - - # 4. 计算边特征 + # 3. 计算边特征 src_feat = map_probs[:, edge_src // W, edge_src % W].T # [E, C] dst_feat = map_probs[:, edge_dst // W, edge_dst % W].T # [E, C] - edge_attr = (src_feat + dst_feat) / 2 * edge_mask # [E, C] + edge_attr = (src_feat + dst_feat) / 2 # [E, C] edge_index, edge_attr = add_self_loops(edge_index, edge_attr)