mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
fix: 输出维度不正确
This commit is contained in:
parent
39b9b110d3
commit
c00a7dc5c1
@ -31,8 +31,8 @@ class GinkaHeatmapModel(nn.Module):
|
|||||||
hidden = hidden.view(B, C, H * W).permute(0, 2, 1) # [B, H * W, d_model]
|
hidden = hidden.view(B, C, H * W).permute(0, 2, 1) # [B, H * W, d_model]
|
||||||
hidden = hidden + self.pos_embedding
|
hidden = hidden + self.pos_embedding
|
||||||
hidden = self.transformer(hidden) # [B, H * W, d_model]
|
hidden = self.transformer(hidden) # [B, H * W, d_model]
|
||||||
output = self.output_fc(hidden)
|
output = self.output_fc(hidden) # [B, H * W, heatmap_dim]
|
||||||
return output.view(B, self.heatmap_dim, H, W)
|
return output.view(B, H, W, self.heatmap_dim).permute(0, 3, 1, 2)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user