fix: 输出维度不正确

This commit is contained in:
unanmed 2026-04-06 19:28:33 +08:00
parent 39b9b110d3
commit c00a7dc5c1

View File

@ -31,8 +31,8 @@ class GinkaHeatmapModel(nn.Module):
hidden = hidden.view(B, C, H * W).permute(0, 2, 1) # [B, H * W, d_model]
hidden = hidden + self.pos_embedding
hidden = self.transformer(hidden) # [B, H * W, d_model]
output = self.output_fc(hidden)
return output.view(B, self.heatmap_dim, H, W)
output = self.output_fc(hidden) # [B, H * W, heatmap_dim]
return output.view(B, H, W, self.heatmap_dim).permute(0, 3, 1, 2)
if __name__ == "__main__":
device = torch.device("cpu")