diff --git a/ginka/model/loss.py b/ginka/model/loss.py index 2a5ab09..f196d7b 100644 --- a/ginka/model/loss.py +++ b/ginka/model/loss.py @@ -47,14 +47,13 @@ def outer_border_constraint_loss(pred: torch.Tensor, allowed_classes=[1, 11], pe return penalty_scale * loss def _create_distance_kernel(size): - """生成带距离权重的卷积核""" - kernel = torch.zeros(2*size-1, 2*size-1) - center = size-1 - for i in range(2*size-1): - for j in range(2*size-1): - dist = math.sqrt((i-center)**2 + (j-center)**2) - kernel[i,j] = 1 / (1 + dist) # 距离越近权重越高 - return kernel.view(1,1,2*size-1,2*size-1) + """生成一个环状衰减核""" + y, x = torch.meshgrid(torch.arange(size), torch.arange(size), indexing='ij') + center = size // 2 + dist = torch.sqrt((x - center)**2 + (y - center)**2) + kernel = torch.exp(-dist / (size / 2)) # 高斯衰减 + kernel /= kernel.sum() # 归一化 + return kernel.unsqueeze(0).unsqueeze(0) # [1,1,H,W] def entrance_constraint_loss( pred: torch.Tensor, @@ -102,9 +101,6 @@ def entrance_constraint_loss( # 平滑惩罚函数:S形曲线 spacing_loss = torch.sigmoid(10*(density_map - 0.5)).mean() # 密度>0.5时快速上升 - - # print(entrance_probs) - print(presence_loss.item(), (density_map).mean().item(), center_weight.mean().item()) ########################### # 区域加权综合损失 @@ -274,11 +270,6 @@ def entrance_spatial_constraint( ########################################## # 2. 边缘区域约束(只能出现箭头) ########################################## - # 提取边缘区域的箭头概率 [B, N_edge_pixels] - edge_arrow_probs = pred_probs[:, arrow_class][:, edge_mask] - - # 边缘应最大化箭头概率(最小化1 - arrow_prob) - edge_arrow_loss = (1 - edge_arrow_probs).mean() # 抑制边缘出现楼梯的概率 [B, N_edge_pixels] edge_stair_probs = pred_probs[:, stair_class][:, edge_mask] @@ -287,11 +278,6 @@ def entrance_spatial_constraint( ########################################## # 3. 中间区域约束(只能出现楼梯) ########################################## - # 提取中间区域的楼梯概率 [B, N_center_pixels] - center_stair_probs = pred_probs[:, stair_class][:, center_mask] - - # 中间应最大化楼梯概率(最小化1 - stair_prob) - center_stair_loss = (1 - center_stair_probs).mean() # 抑制中间出现箭头的概率 [B, N_center_pixels] center_arrow_probs = pred_probs[:, arrow_class][:, center_mask] @@ -301,8 +287,8 @@ def entrance_spatial_constraint( # 4. 综合损失 ########################################## total_loss = ( - lambda_arrow * (edge_arrow_loss + edge_stair_penalty) + - lambda_stair * (center_stair_loss + center_arrow_penalty) + lambda_arrow * edge_stair_penalty + + lambda_stair * center_arrow_penalty ) return total_loss