feat: 优化图块数量损失函数

This commit is contained in:
unanmed 2025-03-26 21:25:09 +08:00
parent a5510ef211
commit 8f892fc7f4
2 changed files with 5 additions and 5 deletions

View File

@ -30,13 +30,13 @@ class GinkaDataset(Dataset):
item = self.data[idx] item = self.data[idx]
target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W] target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
target = random_smooth_onehot(target).to(self.device) target_smooth = random_smooth_onehot(target).to(self.device)
graph = differentiable_convert_to_data(target).to(self.device) graph = differentiable_convert_to_data(target_smooth).to(self.device)
vision_feat, topo_feat = self.minamo(target.unsqueeze(0), graph) vision_feat, topo_feat = self.minamo(target.unsqueeze(0), graph)
return { return {
"target_vision_feat": vision_feat, "target_vision_feat": vision_feat,
"target_topo_feat": topo_feat, "target_topo_feat": topo_feat,
"target": target "target": target,
} }

View File

@ -116,8 +116,8 @@ def adaptive_count_loss(
target_map: torch.Tensor, target_map: torch.Tensor,
class_list: list = list(range(32)), class_list: list = list(range(32)),
margin_ratio: float = 0.2, margin_ratio: float = 0.2,
zero_margin_scale: float = 0.3, zero_margin_scale: float = 0.2,
eps: float = 1e-6 eps: float = 1e-3
) -> torch.Tensor: ) -> torch.Tensor:
""" """
自适应图块数量约束损失函数 自适应图块数量约束损失函数