ginka-generator/ginka/utils.py

94 lines
3.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn.functional as F
import numpy as np
def print_memory(device, tag=""):
    """Report currently allocated and peak CUDA memory (in MB) for *device*.

    The line is prefixed with *tag* to identify the call site. If CUDA is
    unavailable, a notice is printed instead.
    """
    # Guard clause: nothing to measure without CUDA support.
    if not torch.cuda.is_available():
        print("当前设备不支持 cuda.")
        return
    print(f"{tag} | 当前显存: {torch.cuda.memory_allocated(device) / 1024**2:.2f} MB, 最大显存: {torch.cuda.max_memory_allocated(device) / 1024**2:.2f} MB")
def nms_sampling(noise: np.ndarray, k: int, radius=2):
    """Greedy non-maximum-suppression sampling of up to ``k`` peaks.

    Repeatedly takes the global maximum of ``noise``, then suppresses a
    square window of half-width ``radius`` around it so nearby cells cannot
    be selected again.

    Args:
        noise: [H, W] score map (not modified; a copy is suppressed in place).
        k: maximum number of peaks to sample.
        radius: suppression half-width in cells around each chosen peak.

    Returns:
        [H, W] array of the same dtype as ``noise`` with 1 at each sampled
        peak and 0 elsewhere. Fewer than ``k`` peaks are returned if the map
        is exhausted (fully suppressed) first.
    """
    noise = noise.copy()
    points = []
    for _ in range(k):
        idx = np.argmax(noise)
        r, c = np.unravel_index(idx, noise.shape)
        # Map fully suppressed: argmax over all -inf returns index 0, which
        # would add spurious (0, 0) points — stop instead.
        if np.isneginf(noise[r, c]):
            break
        points.append((r, c))
        # Suppress the neighborhood, clipped to the map bounds.
        r0 = max(0, r - radius)
        r1 = min(noise.shape[0], r + radius + 1)
        c0 = max(0, c - radius)
        c1 = min(noise.shape[1], c + radius + 1)
        noise[r0:r1, c0:c1] = -np.inf
    result = np.zeros_like(noise)
    for r, c in points:
        # BUGFIX: was result[y, x] (transposed), which mis-places peaks and
        # raises IndexError on non-square maps; index as [row, col] to match
        # the suppression above.
        result[r, c] = 1
    return result
def masked_focal(
    logits: torch.Tensor,
    target: torch.Tensor,
    tile_set: set,
    gamma: float = 2.0,
    balance: bool = True,
) -> torch.Tensor:
    """Channel-specific focal loss with optional inverse-frequency class weights.

    Positions whose ground-truth tile ID is in ``tile_set`` keep their real
    label as the target; every other position has its label replaced by 0
    ("empty ground"). All positions participate in the loss.

    When ``balance=True``, class weights are computed from the per-batch label
    counts and square-root damped: ``w[c] = sqrt(total / (count[c] * C))``.
    Classes absent from the batch get weight 0.

    Args:
        logits: [B, H*W, num_classes] raw decoder-head scores (pre-softmax).
        target: [B, H*W] full-map ground truth as integer tile IDs.
        tile_set: set of int — the tile IDs owned by this channel.
        gamma: focal-loss focusing parameter.
        balance: enable inverse-frequency class weighting.

    Returns:
        Scalar tensor: mean of the per-position weighted focal loss.
    """
    B, S, C = logits.shape
    # Build a boolean mask of positions whose true tile belongs to this
    # channel, then replace all other labels with 0 (empty ground).
    in_set = torch.zeros(B, S, dtype=torch.bool, device=logits.device)
    for t in tile_set:
        in_set |= (target == t)
    corrected = target.clone()
    corrected[~in_set] = 0
    class_weight = None
    if balance:
        # NOTE(review): counts are taken from `corrected`, so the artificially
        # filled zeros inflate count[0] and shrink weight[0]. An earlier
        # comment claimed frequencies came from the original `target` to avoid
        # exactly that — code and comment disagreed; confirm which behavior is
        # intended before changing anything (raw `target` may also contain
        # tile IDs >= C, which bincount/cross_entropy would not accept here).
        flat = corrected.view(-1)  # [B*S] labels actually fed to the loss
        counts = torch.bincount(flat, minlength=C).float()  # [C]
        # sqrt damping: softer than plain inverse frequency total/(count*C).
        class_weight = torch.sqrt(flat.numel() / (counts.clamp(min=1.0) * C))
        class_weight[counts == 0] = 0.0  # absent classes contribute no loss
    # Per-position weighted cross-entropy, kept unreduced for the focal term.
    ce = F.cross_entropy(
        logits.view(-1, C),
        corrected.view(-1),
        weight=class_weight,
        reduction='none',
    ).view(B, S)  # [B, S]
    # pt = exp(-CE) is the model's probability for the correct class;
    # detached so the focal modulation carries no gradient.
    pt = torch.exp(-ce.detach())
    fl = (1.0 - pt) ** gamma * ce
    return fl.mean()