fix: heatmap 模型最后不应该 Sigmoid

This commit is contained in:
unanmed 2026-04-05 22:18:30 +08:00
parent 15cf564c2e
commit 39b9b110d3

View File

@ -17,8 +17,7 @@ class GinkaHeatmapModel(nn.Module):
self.input = HeatmapCond(T, embed_dim=embed_dim, heatmap_dim=heatmap_dim, output_dim=d_model)
self.transformer = MaskGIT(d_model=d_model, dim_ff=dim_ff, nhead=nhead, num_layers=num_layers)
self.output_fc = nn.Sequential(
nn.Linear(d_model, heatmap_dim),
nn.Sigmoid()
nn.Linear(d_model, heatmap_dim)
)
def forward(self, input: torch.Tensor, cond: torch.Tensor, t: torch.Tensor):