From f0025df1ec7332ed9e067acef5f1756ef1e3792e Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Wed, 6 May 2026 22:27:56 +0800 Subject: [PATCH] fix: focal loss Co-authored-by: Copilot --- ginka/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ginka/utils.py b/ginka/utils.py index a7845ec..6231c84 100644 --- a/ginka/utils.py +++ b/ginka/utils.py @@ -71,10 +71,11 @@ def masked_focal( corrected = target.clone() corrected[~in_set] = 0 - # 逆频类别权重:batch 内频率越高,权重越小 + # 逆频类别权重:用原始 target 统计频率,避免 corrected 中人工填 0 膨胀 + # count[0],导致 weight[0] 趋近于 0、非专属位置损失被消除的问题 class_weight = None if balance: - flat = corrected.view(-1) # [B*S] + flat = target.view(-1) # [B*S] 原始标签 counts = torch.bincount(flat, minlength=C).float() # [C] class_weight = flat.numel() / (counts.clamp(min=1.0) * C) class_weight[counts == 0] = 0.0 # 未出现类别不参与