mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
fix: heatmap 模型最后不应该 Sigmoid
This commit is contained in:
parent
15cf564c2e
commit
39b9b110d3
@ -17,8 +17,7 @@ class GinkaHeatmapModel(nn.Module):
|
||||
self.input = HeatmapCond(T, embed_dim=embed_dim, heatmap_dim=heatmap_dim, output_dim=d_model)
|
||||
self.transformer = MaskGIT(d_model=d_model, dim_ff=dim_ff, nhead=nhead, num_layers=num_layers)
|
||||
self.output_fc = nn.Sequential(
|
||||
nn.Linear(d_model, heatmap_dim),
|
||||
nn.Sigmoid()
|
||||
nn.Linear(d_model, heatmap_dim)
|
||||
)
|
||||
|
||||
def forward(self, input: torch.Tensor, cond: torch.Tensor, t: torch.Tensor):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user