mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-22 19:31:12 +08:00
fix: focal loss
Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
parent
3ed3ad8238
commit
f0025df1ec
@ -71,10 +71,11 @@ def masked_focal(
|
|||||||
corrected = target.clone()
|
corrected = target.clone()
|
||||||
corrected[~in_set] = 0
|
corrected[~in_set] = 0
|
||||||
|
|
||||||
# 逆频类别权重:batch 内频率越高,权重越小
|
# 逆频类别权重:用原始 target 统计频率,避免 corrected 中人工填 0 膨胀
|
||||||
|
# count[0],导致 weight[0] 趋近于 0、非专属位置损失被消除的问题
|
||||||
class_weight = None
|
class_weight = None
|
||||||
if balance:
|
if balance:
|
||||||
flat = corrected.view(-1) # [B*S]
|
flat = target.view(-1) # [B*S] 原始标签
|
||||||
counts = torch.bincount(flat, minlength=C).float() # [C]
|
counts = torch.bincount(flat, minlength=C).float() # [C]
|
||||||
class_weight = flat.numel() / (counts.clamp(min=1.0) * C)
|
class_weight = flat.numel() / (counts.clamp(min=1.0) * C)
|
||||||
class_weight[counts == 0] = 0.0 # 未出现类别不参与
|
class_weight[counts == 0] = 0.0 # 未出现类别不参与
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user