feat: 编码器特征融合添加归一化

This commit is contained in:
unanmed 2026-01-20 23:53:31 +08:00
parent b7fe24ee4c
commit 05b3b7c171

View File

@ -104,7 +104,11 @@ class DecoderInputFusion(nn.Module):
num_layers=2
)
self.norm = nn.LayerNorm(d_model)
self.fusion = nn.Linear(d_model * 2, d_model)
self.fusion = nn.Sequential(
nn.Linear(d_model * 2, d_model),
nn.LayerNorm(d_model),
nn.GELU()
)
def forward(
self, tile_embed: torch.Tensor, cond_vec: torch.Tensor,