From c00a7dc5c199103ce89a4ab18cbe84ddebed8f58 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Mon, 6 Apr 2026 19:28:33 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E8=BE=93=E5=87=BA=E7=BB=B4=E5=BA=A6?= =?UTF-8?q?=E4=B8=8D=E6=AD=A3=E7=A1=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/heatmap/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ginka/heatmap/model.py b/ginka/heatmap/model.py index a9a0f88..9e29e5f 100644 --- a/ginka/heatmap/model.py +++ b/ginka/heatmap/model.py @@ -31,8 +31,8 @@ class GinkaHeatmapModel(nn.Module): hidden = hidden.view(B, C, H * W).permute(0, 2, 1) # [B, H * W, d_model] hidden = hidden + self.pos_embedding hidden = self.transformer(hidden) # [B, H * W, d_model] - output = self.output_fc(hidden) - return output.view(B, self.heatmap_dim, H, W) + output = self.output_fc(hidden) # [B, H * W, heatmap_dim] + return output.view(B, H, W, self.heatmap_dim).permute(0, 3, 1, 2) if __name__ == "__main__": device = torch.device("cpu")