diff --git a/ginka/heatmap/model.py b/ginka/heatmap/model.py index e199b87..a9a0f88 100644 --- a/ginka/heatmap/model.py +++ b/ginka/heatmap/model.py @@ -17,8 +17,7 @@ class GinkaHeatmapModel(nn.Module): self.input = HeatmapCond(T, embed_dim=embed_dim, heatmap_dim=heatmap_dim, output_dim=d_model) self.transformer = MaskGIT(d_model=d_model, dim_ff=dim_ff, nhead=nhead, num_layers=num_layers) self.output_fc = nn.Sequential( - nn.Linear(d_model, heatmap_dim), - nn.Sigmoid() + nn.Linear(d_model, heatmap_dim) ) def forward(self, input: torch.Tensor, cond: torch.Tensor, t: torch.Tensor):