feat: 优化 focal loss

Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
unanmed 2026-05-06 21:55:34 +08:00
parent b9032d94c8
commit 3ed3ad8238

View File

@ -39,26 +39,31 @@ def masked_focal(
target: torch.Tensor, target: torch.Tensor,
tile_set: set, tile_set: set,
gamma: float = 2.0, gamma: float = 2.0,
balance: bool = True,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
通道专属 Focal Losstile_set 内的位置以真实 tile ID 为目标 通道专属 Focal Loss + 逆频类别权重
tile_set 外的位置以 0空地为目标全部位置均参与损失计算
这样模型不仅要学会"这里是什么 tile"还要学会"这里不应该是本通道的 tile" tile_set 内的位置以真实 tile ID 为目标tile_set 外的位置以 0空地为目标
避免解码器在所有位置都输出专属类别来规避损失 全部位置均参与损失计算
balance=True batch corrected 标签的频率自动计算逆频权重
消除空地0因被大量 non-tile-set 位置填充而主导梯度的问题
权重公式w[c] = total / (count[c] * C) sklearn 'balanced' 一致
Args: Args:
logits: [B, H*W, num_classes] 解码头输出未经 softmax logits: [B, H*W, num_classes] 解码头输出未经 softmax
target: [B, H*W] 完整地图 ground truth整数 tile ID target: [B, H*W] 完整地图 ground truth整数 tile ID
tile_set: set of int 本通道专属 tile 集合 tile_set: set of int 本通道专属 tile 集合
gamma: Focal Loss 聚焦参数 gamma: Focal Loss 聚焦参数
balance: 是否开启逆频类别权重
Returns: Returns:
scalar tensor 通道专属 Focal Loss均值 scalar tensor 通道专属加权 Focal Loss均值
""" """
B, S, C = logits.shape B, S, C = logits.shape
# 非专属 tile 位置目标替换为 0空地,专属 tile 位置保持原始标签 # 非专属 tile 位置目标替换为 0空地
in_set = torch.zeros(B, S, dtype=torch.bool, device=logits.device) in_set = torch.zeros(B, S, dtype=torch.bool, device=logits.device)
for t in tile_set: for t in tile_set:
in_set |= (target == t) in_set |= (target == t)
@ -66,10 +71,18 @@ def masked_focal(
corrected = target.clone() corrected = target.clone()
corrected[~in_set] = 0 corrected[~in_set] = 0
# Focal Loss全部位置参与计算 # 逆频类别权重batch 内频率越高,权重越小
class_weight = None
if balance:
flat = corrected.view(-1) # [B*S]
counts = torch.bincount(flat, minlength=C).float() # [C]
class_weight = flat.numel() / (counts.clamp(min=1.0) * C)
class_weight[counts == 0] = 0.0 # 未出现类别不参与
ce = F.cross_entropy( ce = F.cross_entropy(
logits.view(-1, C), logits.view(-1, C),
corrected.view(-1), corrected.view(-1),
weight=class_weight,
reduction='none', reduction='none',
).view(B, S) # [B, S] ).view(B, S) # [B, S]