mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
fix: masked_focal
Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
parent
850c038be3
commit
b8f691269d
@ -39,39 +39,41 @@ def masked_focal(
|
||||
target: torch.Tensor,
|
||||
tile_set: set,
|
||||
gamma: float = 2.0,
|
||||
eps: float = 1e-6,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
通道专属掩码 Focal Loss:仅在 tile_set 中指定的 tile 位置计算损失。
|
||||
通道专属 Focal Loss:tile_set 内的位置以真实 tile ID 为目标,
|
||||
tile_set 外的位置以 0(空地)为目标,全部位置均参与损失计算。
|
||||
|
||||
这样模型不仅要学会"这里是什么 tile",还要学会"这里不应该是本通道的 tile",
|
||||
避免解码器在所有位置都输出专属类别来规避损失。
|
||||
|
||||
Args:
|
||||
logits: [B, H*W, num_classes] 解码头输出(未经 softmax)
|
||||
target: [B, H*W] 完整地图 ground truth(整数 tile ID)
|
||||
tile_set: set of int 本通道专属 tile 集合,其余位置损失权重为 0
|
||||
tile_set: set of int 本通道专属 tile 集合
|
||||
gamma: Focal Loss 聚焦参数
|
||||
eps: 数值稳定的分母偏置
|
||||
|
||||
Returns:
|
||||
scalar tensor 通道专属掩码 Focal Loss
|
||||
scalar tensor 通道专属 Focal Loss(均值)
|
||||
"""
|
||||
B, S, C = logits.shape
|
||||
|
||||
# 构造掩码:仅在专属 tile 位置为 True
|
||||
mask = torch.zeros(B, S, dtype=torch.bool, device=logits.device)
|
||||
# 非专属 tile 位置目标替换为 0(空地),专属 tile 位置保持原始标签
|
||||
in_set = torch.zeros(B, S, dtype=torch.bool, device=logits.device)
|
||||
for t in tile_set:
|
||||
mask |= (target == t)
|
||||
in_set |= (target == t)
|
||||
|
||||
if not mask.any():
|
||||
return logits.sum() * 0.0 # 保留计算图,返回零梯度
|
||||
corrected = target.clone()
|
||||
corrected[~in_set] = 0
|
||||
|
||||
# Focal Loss(reduction='none')
|
||||
# Focal Loss,全部位置参与计算
|
||||
ce = F.cross_entropy(
|
||||
logits.view(-1, C),
|
||||
target.view(-1),
|
||||
corrected.view(-1),
|
||||
reduction='none',
|
||||
).view(B, S) # [B, S]
|
||||
|
||||
pt = torch.exp(-ce.detach()) # 正确类预测概率,stop-gradient
|
||||
fl = (1.0 - pt) ** gamma * ce
|
||||
|
||||
return (fl * mask).sum() / (mask.sum() + eps)
|
||||
return fl.mean()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user