feat: 优化图块数量损失函数

This commit is contained in:
unanmed 2025-03-26 21:25:09 +08:00
parent a5510ef211
commit 8f892fc7f4
2 changed files with 5 additions and 5 deletions

View File

@ -30,13 +30,13 @@ class GinkaDataset(Dataset):
item = self.data[idx]
target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
target = random_smooth_onehot(target).to(self.device)
graph = differentiable_convert_to_data(target).to(self.device)
target_smooth = random_smooth_onehot(target).to(self.device)
graph = differentiable_convert_to_data(target_smooth).to(self.device)
vision_feat, topo_feat = self.minamo(target.unsqueeze(0), graph)
return {
"target_vision_feat": vision_feat,
"target_topo_feat": topo_feat,
"target": target
"target": target,
}

View File

@ -116,8 +116,8 @@ def adaptive_count_loss(
target_map: torch.Tensor,
class_list: list = list(range(32)),
margin_ratio: float = 0.2,
zero_margin_scale: float = 0.3,
eps: float = 1e-6
zero_margin_scale: float = 0.2,
eps: float = 1e-3
) -> torch.Tensor:
"""
自适应图块数量约束损失函数