diff --git a/ginka/dataset.py b/ginka/dataset.py index 0fa3502..add2eda 100644 --- a/ginka/dataset.py +++ b/ginka/dataset.py @@ -30,13 +30,13 @@ class GinkaDataset(Dataset): item = self.data[idx] target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W] - target = random_smooth_onehot(target).to(self.device) - graph = differentiable_convert_to_data(target).to(self.device) + target_smooth = random_smooth_onehot(target).to(self.device) + graph = differentiable_convert_to_data(target_smooth).to(self.device) vision_feat, topo_feat = self.minamo(target.unsqueeze(0), graph) return { "target_vision_feat": vision_feat, "target_topo_feat": topo_feat, - "target": target + "target": target, } \ No newline at end of file diff --git a/ginka/model/loss.py b/ginka/model/loss.py index 55f3522..ea635ca 100644 --- a/ginka/model/loss.py +++ b/ginka/model/loss.py @@ -116,8 +116,8 @@ def adaptive_count_loss( target_map: torch.Tensor, class_list: list = list(range(32)), margin_ratio: float = 0.2, - zero_margin_scale: float = 0.3, - eps: float = 1e-6 + zero_margin_scale: float = 0.2, + eps: float = 1e-3 ) -> torch.Tensor: """ 自适应图块数量约束损失函数