mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
94 lines
3.1 KiB
Python
94 lines
3.1 KiB
Python
import torch
|
||
import torch.nn.functional as F
|
||
import numpy as np
|
||
|
||
def print_memory(device, tag=""):
|
||
if torch.cuda.is_available():
|
||
print(f"{tag} | 当前显存: {torch.cuda.memory_allocated(device) / 1024**2:.2f} MB, 最大显存: {torch.cuda.max_memory_allocated(device) / 1024**2:.2f} MB")
|
||
else:
|
||
print("当前设备不支持 cuda.")
|
||
|
||
def nms_sampling(noise: np.ndarray, k: int, radius=2):
|
||
# noise: [H, W]
|
||
noise = noise.copy()
|
||
points = []
|
||
|
||
for _ in range(k):
|
||
idx = np.argmax(noise)
|
||
x, y = np.unravel_index(idx, noise.shape)
|
||
|
||
points.append((x, y))
|
||
|
||
# 抑制周围
|
||
x0 = max(0, x - radius)
|
||
x1 = min(noise.shape[0], x + radius + 1)
|
||
y0 = max(0, y - radius)
|
||
y1 = min(noise.shape[1], y + radius + 1)
|
||
|
||
noise[x0:x1, y0:y1] = -np.inf
|
||
|
||
result = np.zeros_like(noise)
|
||
for x, y in points:
|
||
result[y, x] = 1
|
||
|
||
return result
|
||
|
||
|
||
def masked_focal(
|
||
logits: torch.Tensor,
|
||
target: torch.Tensor,
|
||
tile_set: set,
|
||
gamma: float = 2.0,
|
||
balance: bool = True,
|
||
) -> torch.Tensor:
|
||
"""
|
||
通道专属 Focal Loss + 逆频类别权重。
|
||
|
||
tile_set 内的位置以真实 tile ID 为目标,tile_set 外的位置以 0(空地)为目标,
|
||
全部位置均参与损失计算。
|
||
|
||
balance=True 时,从 batch 内 corrected 标签的频率自动计算逆频权重,
|
||
消除空地(0)因被大量 non-tile-set 位置填充而主导梯度的问题。
|
||
权重公式:w[c] = total / (count[c] * C),与 sklearn 'balanced' 一致。
|
||
|
||
Args:
|
||
logits: [B, H*W, num_classes] 解码头输出(未经 softmax)
|
||
target: [B, H*W] 完整地图 ground truth(整数 tile ID)
|
||
tile_set: set of int 本通道专属 tile 集合
|
||
gamma: Focal Loss 聚焦参数
|
||
balance: 是否开启逆频类别权重
|
||
|
||
Returns:
|
||
scalar tensor 通道专属加权 Focal Loss(均值)
|
||
"""
|
||
B, S, C = logits.shape
|
||
|
||
# 非专属 tile 位置目标替换为 0(空地)
|
||
in_set = torch.zeros(B, S, dtype=torch.bool, device=logits.device)
|
||
for t in tile_set:
|
||
in_set |= (target == t)
|
||
|
||
corrected = target.clone()
|
||
corrected[~in_set] = 0
|
||
|
||
# 逆频类别权重:用原始 target 统计频率,避免 corrected 中人工填 0 膨胀
|
||
# count[0],导致 weight[0] 趋近于 0、非专属位置损失被消除的问题
|
||
class_weight = None
|
||
if balance:
|
||
flat = corrected.view(-1) # [B*S] 原始标签
|
||
counts = torch.bincount(flat, minlength=C).float() # [C]
|
||
class_weight = torch.sqrt(flat.numel() / (counts.clamp(min=1.0) * C))
|
||
class_weight[counts == 0] = 0.0 # 未出现类别不参与
|
||
|
||
ce = F.cross_entropy(
|
||
logits.view(-1, C),
|
||
corrected.view(-1),
|
||
weight=class_weight,
|
||
reduction='none',
|
||
).view(B, S) # [B, S]
|
||
|
||
pt = torch.exp(-ce.detach()) # 正确类预测概率,stop-gradient
|
||
fl = (1.0 - pt) ** gamma * ce
|
||
|
||
return fl.mean()
|