mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
feat: 优化图块数量损失函数
This commit is contained in:
parent
a5510ef211
commit
8f892fc7f4
@ -30,13 +30,13 @@ class GinkaDataset(Dataset):
|
|||||||
item = self.data[idx]
|
item = self.data[idx]
|
||||||
|
|
||||||
target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
|
target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
|
||||||
target = random_smooth_onehot(target).to(self.device)
|
target_smooth = random_smooth_onehot(target).to(self.device)
|
||||||
graph = differentiable_convert_to_data(target).to(self.device)
|
graph = differentiable_convert_to_data(target_smooth).to(self.device)
|
||||||
vision_feat, topo_feat = self.minamo(target.unsqueeze(0), graph)
|
vision_feat, topo_feat = self.minamo(target.unsqueeze(0), graph)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"target_vision_feat": vision_feat,
|
"target_vision_feat": vision_feat,
|
||||||
"target_topo_feat": topo_feat,
|
"target_topo_feat": topo_feat,
|
||||||
"target": target
|
"target": target,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -116,8 +116,8 @@ def adaptive_count_loss(
|
|||||||
target_map: torch.Tensor,
|
target_map: torch.Tensor,
|
||||||
class_list: list = list(range(32)),
|
class_list: list = list(range(32)),
|
||||||
margin_ratio: float = 0.2,
|
margin_ratio: float = 0.2,
|
||||||
zero_margin_scale: float = 0.3,
|
zero_margin_scale: float = 0.2,
|
||||||
eps: float = 1e-6
|
eps: float = 1e-3
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
自适应图块数量约束损失函数
|
自适应图块数量约束损失函数
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user