diff --git a/ginka/heatmap/model.py b/ginka/heatmap/model.py index a9a0f88..9e29e5f 100644 --- a/ginka/heatmap/model.py +++ b/ginka/heatmap/model.py @@ -31,8 +31,8 @@ class GinkaHeatmapModel(nn.Module): hidden = hidden.view(B, C, H * W).permute(0, 2, 1) # [B, H * W, d_model] hidden = hidden + self.pos_embedding hidden = self.transformer(hidden) # [B, H * W, d_model] - output = self.output_fc(hidden) - return output.view(B, self.heatmap_dim, H, W) + output = self.output_fc(hidden) # [B, H * W, heatmap_dim] + return output.view(B, H, W, self.heatmap_dim).permute(0, 3, 1, 2) if __name__ == "__main__": device = torch.device("cpu")