fix: 入口相关的损失函数

This commit is contained in:
unanmed 2025-03-20 22:47:33 +08:00
parent 171dcf60f1
commit 68b83c6339

View File

@ -47,14 +47,13 @@ def outer_border_constraint_loss(pred: torch.Tensor, allowed_classes=[1, 11], pe
return penalty_scale * loss
def _create_distance_kernel(size):
"""生成带距离权重的卷积核"""
kernel = torch.zeros(2*size-1, 2*size-1)
center = size-1
for i in range(2*size-1):
for j in range(2*size-1):
dist = math.sqrt((i-center)**2 + (j-center)**2)
kernel[i,j] = 1 / (1 + dist) # 距离越近权重越高
return kernel.view(1,1,2*size-1,2*size-1)
"""生成一个环状衰减核"""
y, x = torch.meshgrid(torch.arange(size), torch.arange(size), indexing='ij')
center = size // 2
dist = torch.sqrt((x - center)**2 + (y - center)**2)
kernel = torch.exp(-dist / (size / 2)) # 高斯衰减
kernel /= kernel.sum() # 归一化
return kernel.unsqueeze(0).unsqueeze(0) # [1,1,H,W]
def entrance_constraint_loss(
pred: torch.Tensor,
@ -102,9 +101,6 @@ def entrance_constraint_loss(
# 平滑惩罚函数S形曲线
spacing_loss = torch.sigmoid(10*(density_map - 0.5)).mean() # 密度>0.5时快速上升
# print(entrance_probs)
print(presence_loss.item(), (density_map).mean().item(), center_weight.mean().item())
###########################
# 区域加权综合损失
@ -274,11 +270,6 @@ def entrance_spatial_constraint(
##########################################
# 2. 边缘区域约束(只能出现箭头)
##########################################
# 提取边缘区域的箭头概率 [B, N_edge_pixels]
edge_arrow_probs = pred_probs[:, arrow_class][:, edge_mask]
# 边缘应最大化箭头概率最小化1 - arrow_prob
edge_arrow_loss = (1 - edge_arrow_probs).mean()
# 抑制边缘出现楼梯的概率 [B, N_edge_pixels]
edge_stair_probs = pred_probs[:, stair_class][:, edge_mask]
@ -287,11 +278,6 @@ def entrance_spatial_constraint(
##########################################
# 3. 中间区域约束(只能出现楼梯)
##########################################
# 提取中间区域的楼梯概率 [B, N_center_pixels]
center_stair_probs = pred_probs[:, stair_class][:, center_mask]
# 中间应最大化楼梯概率最小化1 - stair_prob
center_stair_loss = (1 - center_stair_probs).mean()
# 抑制中间出现箭头的概率 [B, N_center_pixels]
center_arrow_probs = pred_probs[:, arrow_class][:, center_mask]
@ -301,8 +287,8 @@ def entrance_spatial_constraint(
# 4. 综合损失
##########################################
total_loss = (
lambda_arrow * (edge_arrow_loss + edge_stair_penalty) +
lambda_stair * (center_stair_loss + center_arrow_penalty)
lambda_arrow * edge_stair_penalty +
lambda_stair * center_arrow_penalty
)
return total_loss