diff --git a/ginka/vae_rnn/decoder.py b/ginka/vae_rnn/decoder.py index 6a9c679..2317333 100644 --- a/ginka/vae_rnn/decoder.py +++ b/ginka/vae_rnn/decoder.py @@ -104,7 +104,11 @@ class DecoderInputFusion(nn.Module): num_layers=2 ) self.norm = nn.LayerNorm(d_model) - self.fusion = nn.Linear(d_model * 2, d_model) + self.fusion = nn.Sequential( + nn.Linear(d_model * 2, d_model), + nn.LayerNorm(d_model), + nn.GELU() + ) def forward( self, tile_embed: torch.Tensor, cond_vec: torch.Tensor,