mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
fix: 入口相关的损失函数
This commit is contained in:
parent
171dcf60f1
commit
68b83c6339
@ -47,14 +47,13 @@ def outer_border_constraint_loss(pred: torch.Tensor, allowed_classes=[1, 11], pe
|
||||
return penalty_scale * loss
|
||||
|
||||
def _create_distance_kernel(size):
|
||||
"""生成带距离权重的卷积核"""
|
||||
kernel = torch.zeros(2*size-1, 2*size-1)
|
||||
center = size-1
|
||||
for i in range(2*size-1):
|
||||
for j in range(2*size-1):
|
||||
dist = math.sqrt((i-center)**2 + (j-center)**2)
|
||||
kernel[i,j] = 1 / (1 + dist) # 距离越近权重越高
|
||||
return kernel.view(1,1,2*size-1,2*size-1)
|
||||
"""生成一个环状衰减核"""
|
||||
y, x = torch.meshgrid(torch.arange(size), torch.arange(size), indexing='ij')
|
||||
center = size // 2
|
||||
dist = torch.sqrt((x - center)**2 + (y - center)**2)
|
||||
kernel = torch.exp(-dist / (size / 2)) # 高斯衰减
|
||||
kernel /= kernel.sum() # 归一化
|
||||
return kernel.unsqueeze(0).unsqueeze(0) # [1,1,H,W]
|
||||
|
||||
def entrance_constraint_loss(
|
||||
pred: torch.Tensor,
|
||||
@ -102,9 +101,6 @@ def entrance_constraint_loss(
|
||||
|
||||
# 平滑惩罚函数:S形曲线
|
||||
spacing_loss = torch.sigmoid(10*(density_map - 0.5)).mean() # 密度>0.5时快速上升
|
||||
|
||||
# print(entrance_probs)
|
||||
print(presence_loss.item(), (density_map).mean().item(), center_weight.mean().item())
|
||||
|
||||
###########################
|
||||
# 区域加权综合损失
|
||||
@ -274,11 +270,6 @@ def entrance_spatial_constraint(
|
||||
##########################################
|
||||
# 2. 边缘区域约束(只能出现箭头)
|
||||
##########################################
|
||||
# 提取边缘区域的箭头概率 [B, N_edge_pixels]
|
||||
edge_arrow_probs = pred_probs[:, arrow_class][:, edge_mask]
|
||||
|
||||
# 边缘应最大化箭头概率(最小化1 - arrow_prob)
|
||||
edge_arrow_loss = (1 - edge_arrow_probs).mean()
|
||||
|
||||
# 抑制边缘出现楼梯的概率 [B, N_edge_pixels]
|
||||
edge_stair_probs = pred_probs[:, stair_class][:, edge_mask]
|
||||
@ -287,11 +278,6 @@ def entrance_spatial_constraint(
|
||||
##########################################
|
||||
# 3. 中间区域约束(只能出现楼梯)
|
||||
##########################################
|
||||
# 提取中间区域的楼梯概率 [B, N_center_pixels]
|
||||
center_stair_probs = pred_probs[:, stair_class][:, center_mask]
|
||||
|
||||
# 中间应最大化楼梯概率(最小化1 - stair_prob)
|
||||
center_stair_loss = (1 - center_stair_probs).mean()
|
||||
|
||||
# 抑制中间出现箭头的概率 [B, N_center_pixels]
|
||||
center_arrow_probs = pred_probs[:, arrow_class][:, center_mask]
|
||||
@ -301,8 +287,8 @@ def entrance_spatial_constraint(
|
||||
# 4. 综合损失
|
||||
##########################################
|
||||
total_loss = (
|
||||
lambda_arrow * (edge_arrow_loss + edge_stair_penalty) +
|
||||
lambda_stair * (center_stair_loss + center_arrow_penalty)
|
||||
lambda_arrow * edge_stair_penalty +
|
||||
lambda_stair * center_arrow_penalty
|
||||
)
|
||||
|
||||
return total_loss
|
||||
|
||||
Loading…
Reference in New Issue
Block a user