mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-24 05:01:41 +08:00
perf: 加强 GCN 部分
This commit is contained in:
parent
55f09fb37b
commit
7b138c66d9
@ -25,9 +25,11 @@ class GCNBlock(nn.Module):
|
|||||||
def __init__(self, in_ch, hidden_ch, out_ch, w, h):
|
def __init__(self, in_ch, hidden_ch, out_ch, w, h):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.conv1 = GCNConv(in_ch, hidden_ch)
|
self.conv1 = GCNConv(in_ch, hidden_ch)
|
||||||
self.conv2 = GCNConv(hidden_ch, out_ch)
|
self.conv2 = GCNConv(hidden_ch, hidden_ch)
|
||||||
|
self.conv3 = GCNConv(hidden_ch, out_ch)
|
||||||
self.norm1 = nn.LayerNorm(hidden_ch)
|
self.norm1 = nn.LayerNorm(hidden_ch)
|
||||||
self.norm2 = nn.LayerNorm(out_ch)
|
self.norm2 = nn.LayerNorm(hidden_ch)
|
||||||
|
self.norm3 = nn.LayerNorm(out_ch)
|
||||||
self.single_edge_index, _ = grid(h, w) # [2, E] for a single map
|
self.single_edge_index, _ = grid(h, w) # [2, E] for a single map
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
@ -49,6 +51,8 @@ class GCNBlock(nn.Module):
|
|||||||
x = F.elu(self.norm1(x))
|
x = F.elu(self.norm1(x))
|
||||||
x = self.conv2(x, edge_index)
|
x = self.conv2(x, edge_index)
|
||||||
x = F.elu(self.norm2(x))
|
x = F.elu(self.norm2(x))
|
||||||
|
x = self.conv3(x, edge_index)
|
||||||
|
x = F.elu(self.norm3(x))
|
||||||
|
|
||||||
# Reshape back to [B, C, H, W]
|
# Reshape back to [B, C, H, W]
|
||||||
x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
|
x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
|
||||||
|
|||||||
@ -27,15 +27,15 @@ class InputUpsample(nn.Module):
|
|||||||
def __init__(self, in_ch, hidden_ch=64, out_ch=64):
|
def __init__(self, in_ch, hidden_ch=64, out_ch=64):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.net = nn.Sequential(
|
self.net = nn.Sequential(
|
||||||
nn.Conv2d(in_ch, hidden_ch, kernel_size=3, padding=1),
|
ConvFusionModule(in_ch, hidden_ch, hidden_ch, 13, 13),
|
||||||
nn.ELU(),
|
nn.ELU(),
|
||||||
|
|
||||||
nn.Upsample(scale_factor=2, mode='nearest'), # 13x13 → 26x26
|
nn.Upsample(scale_factor=2, mode='nearest'), # 13x13 → 26x26
|
||||||
nn.Conv2d(hidden_ch, hidden_ch, kernel_size=3, padding=1),
|
ConvFusionModule(hidden_ch, hidden_ch, hidden_ch, 26, 26),
|
||||||
nn.ELU(),
|
nn.ELU(),
|
||||||
|
|
||||||
nn.Upsample(size=(32, 32), mode='nearest'), # 26x26 → 32x32
|
nn.Upsample(size=(32, 32), mode='nearest'), # 26x26 → 32x32
|
||||||
nn.Conv2d(hidden_ch, out_ch, kernel_size=3, padding=1),
|
ConvFusionModule(hidden_ch, hidden_ch, out_ch, 32, 32),
|
||||||
nn.ELU(),
|
nn.ELU(),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -52,11 +52,13 @@ class GinkaInput(nn.Module):
|
|||||||
self.enc2 = ConvFusionModule(out_ch, out_ch*4, out_ch, out_size[0], out_size[1])
|
self.enc2 = ConvFusionModule(out_ch, out_ch*4, out_ch, out_size[0], out_size[1])
|
||||||
self.inject1 = ConditionInjector(256, in_ch)
|
self.inject1 = ConditionInjector(256, in_ch)
|
||||||
self.inject2 = ConditionInjector(256, out_ch)
|
self.inject2 = ConditionInjector(256, out_ch)
|
||||||
|
self.inject3 = ConditionInjector(256, out_ch)
|
||||||
|
|
||||||
def forward(self, x, cond):
|
def forward(self, x, cond):
|
||||||
x = self.enc1(x)
|
x = self.enc1(x)
|
||||||
x = self.inject1(x, cond)
|
x = self.inject1(x, cond)
|
||||||
x = self.upsample(x)
|
x = self.upsample(x)
|
||||||
x = self.enc2(x)
|
|
||||||
x = self.inject2(x, cond)
|
x = self.inject2(x, cond)
|
||||||
|
x = self.enc2(x)
|
||||||
|
x = self.inject3(x, cond)
|
||||||
return x
|
return x
|
||||||
|
|||||||
@ -158,10 +158,13 @@ class GinkaBottleneck(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.transformer = GinkaTransformerEncoder(
|
self.transformer = GinkaTransformerEncoder(
|
||||||
in_dim=module_ch*w*h, hidden_dim=module_ch*w*h, out_dim=module_ch*w*h,
|
in_dim=module_ch*w*h, hidden_dim=module_ch*w*h, out_dim=module_ch*w*h,
|
||||||
token_size=16, ff_dim=1024, num_layers=6
|
token_size=16, ff_dim=1024, num_layers=4
|
||||||
)
|
)
|
||||||
self.gcn = GCNBlock(module_ch, module_ch*2, module_ch, 4, 4)
|
self.gcn = GCNBlock(module_ch, module_ch*2, module_ch, 4, 4)
|
||||||
self.fusion = nn.Conv2d(module_ch*3, module_ch, 1)
|
self.fusion = nn.Conv2d(module_ch*3, module_ch, 1)
|
||||||
|
# self.conv = ConvBlock(module_ch, module_ch)
|
||||||
|
# self.gcn = GCNBlock(module_ch, module_ch*2, module_ch, w, h)
|
||||||
|
# self.fusion = FusionModule(module_ch*2, module_ch)
|
||||||
self.inject = ConditionInjector(256, module_ch)
|
self.inject = ConditionInjector(256, module_ch)
|
||||||
|
|
||||||
def forward(self, x, cond):
|
def forward(self, x, cond):
|
||||||
|
|||||||
@ -350,8 +350,8 @@ def train():
|
|||||||
else:
|
else:
|
||||||
g_steps = 1
|
g_steps = 1
|
||||||
|
|
||||||
if avg_loss_ginka > 0:
|
if avg_loss_ginka > 0 and epoch > 20 and not args.resume:
|
||||||
g_steps += int(max(avg_loss_ginka * 5, 0))
|
g_steps += int(min(avg_loss_ginka * 5, 50))
|
||||||
|
|
||||||
if avg_loss_minamo > 0:
|
if avg_loss_minamo > 0:
|
||||||
c_steps = int(min(5 + avg_loss_minamo * 5, 15))
|
c_steps = int(min(5 + avg_loss_minamo * 5, 15))
|
||||||
|
|||||||
@ -35,16 +35,10 @@ def differentiable_convert_to_data(map_probs: torch.Tensor) -> Data:
|
|||||||
torch.stack([edge_dst, edge_src], dim=0) # 反向连接
|
torch.stack([edge_dst, edge_src], dim=0) # 反向连接
|
||||||
], dim=1).to(device, dtype=torch.long)
|
], dim=1).to(device, dtype=torch.long)
|
||||||
|
|
||||||
# 3. 计算可导的边权重
|
# 3. 计算边特征
|
||||||
wall_class_idx = 1 # 假设类别 1 是墙
|
|
||||||
src_probs = torch.sigmoid(-map_probs[wall_class_idx].flatten()[edge_src])
|
|
||||||
dst_probs = torch.sigmoid(-map_probs[wall_class_idx].flatten()[edge_dst])
|
|
||||||
edge_mask = torch.nn.functional.softplus(src_probs * dst_probs).unsqueeze(1) # [E, 1]
|
|
||||||
|
|
||||||
# 4. 计算边特征
|
|
||||||
src_feat = map_probs[:, edge_src // W, edge_src % W].T # [E, C]
|
src_feat = map_probs[:, edge_src // W, edge_src % W].T # [E, C]
|
||||||
dst_feat = map_probs[:, edge_dst // W, edge_dst % W].T # [E, C]
|
dst_feat = map_probs[:, edge_dst // W, edge_dst % W].T # [E, C]
|
||||||
edge_attr = (src_feat + dst_feat) / 2 * edge_mask # [E, C]
|
edge_attr = (src_feat + dst_feat) / 2 # [E, C]
|
||||||
|
|
||||||
edge_index, edge_attr = add_self_loops(edge_index, edge_attr)
|
edge_index, edge_attr = add_self_loops(edge_index, edge_attr)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user