diff --git a/ginka/utils.py b/ginka/utils.py index 114f5f6..a7845ec 100644 --- a/ginka/utils.py +++ b/ginka/utils.py @@ -39,26 +39,31 @@ def masked_focal( target: torch.Tensor, tile_set: set, gamma: float = 2.0, + balance: bool = True, ) -> torch.Tensor: """ - 通道专属 Focal Loss:tile_set 内的位置以真实 tile ID 为目标, - tile_set 外的位置以 0(空地)为目标,全部位置均参与损失计算。 + 通道专属 Focal Loss + 逆频类别权重。 - 这样模型不仅要学会"这里是什么 tile",还要学会"这里不应该是本通道的 tile", - 避免解码器在所有位置都输出专属类别来规避损失。 + tile_set 内的位置以真实 tile ID 为目标,tile_set 外的位置以 0(空地)为目标, + 全部位置均参与损失计算。 + + balance=True 时,从 batch 内 corrected 标签的频率自动计算逆频权重, + 消除空地(0)因被大量 non-tile-set 位置填充而主导梯度的问题。 + 权重公式:w[c] = total / (count[c] * C),与 sklearn 'balanced' 一致。 Args: logits: [B, H*W, num_classes] 解码头输出(未经 softmax) target: [B, H*W] 完整地图 ground truth(整数 tile ID) tile_set: set of int 本通道专属 tile 集合 gamma: Focal Loss 聚焦参数 + balance: 是否开启逆频类别权重 Returns: - scalar tensor 通道专属 Focal Loss(均值) + scalar tensor 通道专属加权 Focal Loss(均值) """ B, S, C = logits.shape - # 非专属 tile 位置目标替换为 0(空地),专属 tile 位置保持原始标签 + # 非专属 tile 位置目标替换为 0(空地) in_set = torch.zeros(B, S, dtype=torch.bool, device=logits.device) for t in tile_set: in_set |= (target == t) @@ -66,10 +71,18 @@ def masked_focal( corrected = target.clone() corrected[~in_set] = 0 - # Focal Loss,全部位置参与计算 + # 逆频类别权重:batch 内频率越高,权重越小 + class_weight = None + if balance: + flat = corrected.view(-1) # [B*S] + counts = torch.bincount(flat, minlength=C).float() # [C] + class_weight = flat.numel() / (counts.clamp(min=1.0) * C) + class_weight[counts == 0] = 0.0 # 未出现类别不参与 + ce = F.cross_entropy( logits.view(-1, C), corrected.view(-1), + weight=class_weight, reduction='none', ).view(B, S) # [B, S]