mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
feat: 优化图块数量损失函数
This commit is contained in:
parent
a5510ef211
commit
8f892fc7f4
@ -30,13 +30,13 @@ class GinkaDataset(Dataset):
|
||||
item = self.data[idx]
|
||||
|
||||
target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
|
||||
target = random_smooth_onehot(target).to(self.device)
|
||||
graph = differentiable_convert_to_data(target).to(self.device)
|
||||
target_smooth = random_smooth_onehot(target).to(self.device)
|
||||
graph = differentiable_convert_to_data(target_smooth).to(self.device)
|
||||
vision_feat, topo_feat = self.minamo(target.unsqueeze(0), graph)
|
||||
|
||||
return {
|
||||
"target_vision_feat": vision_feat,
|
||||
"target_topo_feat": topo_feat,
|
||||
"target": target
|
||||
"target": target,
|
||||
}
|
||||
|
||||
@ -116,8 +116,8 @@ def adaptive_count_loss(
|
||||
target_map: torch.Tensor,
|
||||
class_list: list = list(range(32)),
|
||||
margin_ratio: float = 0.2,
|
||||
zero_margin_scale: float = 0.3,
|
||||
eps: float = 1e-6
|
||||
zero_margin_scale: float = 0.2,
|
||||
eps: float = 1e-3
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
自适应图块数量约束损失函数
|
||||
|
||||
Loading…
Reference in New Issue
Block a user