From 8f892fc7f4b596285a3ce4eeac5c5d00c8a37a36 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Wed, 26 Mar 2025 21:25:09 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E4=BC=98=E5=8C=96=E5=9B=BE=E5=9D=97?= =?UTF-8?q?=E6=95=B0=E9=87=8F=E6=8D=9F=E5=A4=B1=E5=87=BD=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/dataset.py | 6 +++--- ginka/model/loss.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/ginka/dataset.py b/ginka/dataset.py index 0fa3502..add2eda 100644 --- a/ginka/dataset.py +++ b/ginka/dataset.py @@ -30,13 +30,13 @@ class GinkaDataset(Dataset): item = self.data[idx] target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W] - target = random_smooth_onehot(target).to(self.device) - graph = differentiable_convert_to_data(target).to(self.device) + target_smooth = random_smooth_onehot(target).to(self.device) + graph = differentiable_convert_to_data(target_smooth).to(self.device) vision_feat, topo_feat = self.minamo(target.unsqueeze(0), graph) return { "target_vision_feat": vision_feat, "target_topo_feat": topo_feat, - "target": target + "target": target, } \ No newline at end of file diff --git a/ginka/model/loss.py b/ginka/model/loss.py index 55f3522..ea635ca 100644 --- a/ginka/model/loss.py +++ b/ginka/model/loss.py @@ -116,8 +116,8 @@ def adaptive_count_loss( target_map: torch.Tensor, class_list: list = list(range(32)), margin_ratio: float = 0.2, - zero_margin_scale: float = 0.3, - eps: float = 1e-6 + zero_margin_scale: float = 0.2, + eps: float = 1e-3 ) -> torch.Tensor: """ 自适应图块数量约束损失函数