mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-19 08:21:11 +08:00
fix: heatmap 模型最后不应该 Sigmoid
This commit is contained in:
parent
15cf564c2e
commit
39b9b110d3
@ -17,8 +17,7 @@ class GinkaHeatmapModel(nn.Module):
|
|||||||
self.input = HeatmapCond(T, embed_dim=embed_dim, heatmap_dim=heatmap_dim, output_dim=d_model)
|
self.input = HeatmapCond(T, embed_dim=embed_dim, heatmap_dim=heatmap_dim, output_dim=d_model)
|
||||||
self.transformer = MaskGIT(d_model=d_model, dim_ff=dim_ff, nhead=nhead, num_layers=num_layers)
|
self.transformer = MaskGIT(d_model=d_model, dim_ff=dim_ff, nhead=nhead, num_layers=num_layers)
|
||||||
self.output_fc = nn.Sequential(
|
self.output_fc = nn.Sequential(
|
||||||
nn.Linear(d_model, heatmap_dim),
|
nn.Linear(d_model, heatmap_dim)
|
||||||
nn.Sigmoid()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, input: torch.Tensor, cond: torch.Tensor, t: torch.Tensor):
|
def forward(self, input: torch.Tensor, cond: torch.Tensor, t: torch.Tensor):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user