mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
feat: 优化 focal loss
Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
parent
b9032d94c8
commit
3ed3ad8238
@ -39,26 +39,31 @@ def masked_focal(
|
|||||||
target: torch.Tensor,
|
target: torch.Tensor,
|
||||||
tile_set: set,
|
tile_set: set,
|
||||||
gamma: float = 2.0,
|
gamma: float = 2.0,
|
||||||
|
balance: bool = True,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
通道专属 Focal Loss:tile_set 内的位置以真实 tile ID 为目标,
|
通道专属 Focal Loss + 逆频类别权重。
|
||||||
tile_set 外的位置以 0(空地)为目标,全部位置均参与损失计算。
|
|
||||||
|
|
||||||
这样模型不仅要学会"这里是什么 tile",还要学会"这里不应该是本通道的 tile",
|
tile_set 内的位置以真实 tile ID 为目标,tile_set 外的位置以 0(空地)为目标,
|
||||||
避免解码器在所有位置都输出专属类别来规避损失。
|
全部位置均参与损失计算。
|
||||||
|
|
||||||
|
balance=True 时,从 batch 内 corrected 标签的频率自动计算逆频权重,
|
||||||
|
消除空地(0)因被大量 non-tile-set 位置填充而主导梯度的问题。
|
||||||
|
权重公式:w[c] = total / (count[c] * C),与 sklearn 'balanced' 一致。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
logits: [B, H*W, num_classes] 解码头输出(未经 softmax)
|
logits: [B, H*W, num_classes] 解码头输出(未经 softmax)
|
||||||
target: [B, H*W] 完整地图 ground truth(整数 tile ID)
|
target: [B, H*W] 完整地图 ground truth(整数 tile ID)
|
||||||
tile_set: set of int 本通道专属 tile 集合
|
tile_set: set of int 本通道专属 tile 集合
|
||||||
gamma: Focal Loss 聚焦参数
|
gamma: Focal Loss 聚焦参数
|
||||||
|
balance: 是否开启逆频类别权重
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
scalar tensor 通道专属 Focal Loss(均值)
|
scalar tensor 通道专属加权 Focal Loss(均值)
|
||||||
"""
|
"""
|
||||||
B, S, C = logits.shape
|
B, S, C = logits.shape
|
||||||
|
|
||||||
# 非专属 tile 位置目标替换为 0(空地),专属 tile 位置保持原始标签
|
# 非专属 tile 位置目标替换为 0(空地)
|
||||||
in_set = torch.zeros(B, S, dtype=torch.bool, device=logits.device)
|
in_set = torch.zeros(B, S, dtype=torch.bool, device=logits.device)
|
||||||
for t in tile_set:
|
for t in tile_set:
|
||||||
in_set |= (target == t)
|
in_set |= (target == t)
|
||||||
@ -66,10 +71,18 @@ def masked_focal(
|
|||||||
corrected = target.clone()
|
corrected = target.clone()
|
||||||
corrected[~in_set] = 0
|
corrected[~in_set] = 0
|
||||||
|
|
||||||
# Focal Loss,全部位置参与计算
|
# 逆频类别权重:batch 内频率越高,权重越小
|
||||||
|
class_weight = None
|
||||||
|
if balance:
|
||||||
|
flat = corrected.view(-1) # [B*S]
|
||||||
|
counts = torch.bincount(flat, minlength=C).float() # [C]
|
||||||
|
class_weight = flat.numel() / (counts.clamp(min=1.0) * C)
|
||||||
|
class_weight[counts == 0] = 0.0 # 未出现类别不参与
|
||||||
|
|
||||||
ce = F.cross_entropy(
|
ce = F.cross_entropy(
|
||||||
logits.view(-1, C),
|
logits.view(-1, C),
|
||||||
corrected.view(-1),
|
corrected.view(-1),
|
||||||
|
weight=class_weight,
|
||||||
reduction='none',
|
reduction='none',
|
||||||
).view(B, S) # [B, S]
|
).view(B, S) # [B, S]
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user