diff --git a/ginka/utils.py b/ginka/utils.py index b604f9e..114f5f6 100644 --- a/ginka/utils.py +++ b/ginka/utils.py @@ -39,39 +39,41 @@ def masked_focal( target: torch.Tensor, tile_set: set, gamma: float = 2.0, - eps: float = 1e-6, ) -> torch.Tensor: """ - 通道专属掩码 Focal Loss:仅在 tile_set 中指定的 tile 位置计算损失。 + 通道专属 Focal Loss:tile_set 内的位置以真实 tile ID 为目标, + tile_set 外的位置以 0(空地)为目标,全部位置均参与损失计算。 + + 这样模型不仅要学会"这里是什么 tile",还要学会"这里不应该是本通道的 tile", + 避免解码器在所有位置都输出专属类别来规避损失。 Args: logits: [B, H*W, num_classes] 解码头输出(未经 softmax) target: [B, H*W] 完整地图 ground truth(整数 tile ID) - tile_set: set of int 本通道专属 tile 集合,其余位置损失权重为 0 + tile_set: set of int 本通道专属 tile 集合 gamma: Focal Loss 聚焦参数 - eps: 数值稳定的分母偏置 Returns: - scalar tensor 通道专属掩码 Focal Loss + scalar tensor 通道专属 Focal Loss(均值) """ B, S, C = logits.shape - # 构造掩码:仅在专属 tile 位置为 True - mask = torch.zeros(B, S, dtype=torch.bool, device=logits.device) + # 非专属 tile 位置目标替换为 0(空地),专属 tile 位置保持原始标签 + in_set = torch.zeros(B, S, dtype=torch.bool, device=logits.device) for t in tile_set: - mask |= (target == t) + in_set |= (target == t) - if not mask.any(): - return logits.sum() * 0.0 # 保留计算图,返回零梯度 + corrected = target.clone() + corrected[~in_set] = 0 - # Focal Loss(reduction='none') + # Focal Loss,全部位置参与计算 ce = F.cross_entropy( logits.view(-1, C), - target.view(-1), + corrected.view(-1), reduction='none', ).view(B, S) # [B, S] pt = torch.exp(-ce.detach()) # 正确类预测概率,stop-gradient fl = (1.0 - pt) ** gamma * ce - return (fl * mask).sum() / (mask.sum() + eps) + return fl.mean()