From 171dcf60f174804cf604330778e105a18b5d4c24 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Thu, 20 Mar 2025 13:41:54 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E9=87=8D=E6=9E=84=E6=8D=9F?= =?UTF-8?q?=E5=A4=B1=E5=87=BD=E6=95=B0=E9=83=A8=E5=88=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/model/loss.py | 567 ++++++++++++++++++++++--------------------- ginka/model/model.py | 4 +- ginka/train.py | 8 +- shared/graph.py | 16 +- 4 files changed, 313 insertions(+), 282 deletions(-) diff --git a/ginka/model/loss.py b/ginka/model/loss.py index 3166648..2a5ab09 100644 --- a/ginka/model/loss.py +++ b/ginka/model/loss.py @@ -3,300 +3,321 @@ import torch import torch.nn as nn import torch.nn.functional as F from minamo.model.model import MinamoModel -from shared.graph import convert_soft_map_to_graph +from shared.graph import batch_convert_soft_map_to_graph -def wall_border_loss(pred: torch.Tensor, allow_border=[1, 11]): - """地图最外层是否为墙""" - # 计算 softmax 概率 +def outer_border_constraint_loss(pred: torch.Tensor, allowed_classes=[1, 11], penalty_scale=1.0): + """ + 强制地图最外圈像素必须为指定类别(墙或箭头) + + 参数: + pred: 模型输出的概率分布,形状 [B, C, H, W] + allowed_classes: 允许出现在外圈的类别列表(默认[1,11]) + penalty_scale: 惩罚强度系数 + + 返回: + loss: 标量损失值 + """ B, C, H, W = pred.shape - - # 构造一个 [H, W] 的布尔 mask,选取最外圈的像素 + + # 创建外圈mask [H, W] border_mask = torch.zeros((H, W), dtype=torch.bool, device=pred.device) - border_mask[0, :] = True - border_mask[-1, :] = True - border_mask[:, 0] = True - border_mask[:, -1] = True - - # 对允许的类别求概率和(即该像素为允许类别的总概率) - allowed_prob = pred[:, allow_border, :, :].sum(dim=1) # [B, H, W] - - # 只计算边界区域的损失:对于边界上的每个像素,要求 allowed_prob 越高越好 - border_allowed_prob = allowed_prob[:, border_mask] # [B, N_border_pixels] - - # 损失为 -log(allowed_prob) - loss = 1 - border_allowed_prob.mean() - - return loss - -def internal_wall_loss(pred, wall_class=1, threshold=2.5): - """ - 针对内部区域(排除最外圈)设计的损失函数: - 当内部任意 2×2 区域的 wall 类别概率之和超过阈值时,施加惩罚。 - - 参数: - pred: 模型输出,形状 [B, C, H, W] - wall_class: 对应墙壁的类别索引(这里假设墙壁数字为1) - threshold: 2×2 区域概率之和的阈值,超过此值时施加惩罚。可根据实际情况调节。 + border_mask[0, :] = True # 第一行 + border_mask[-1, :] = True # 最后一行 + border_mask[:, 0] = True # 第一列 + border_mask[:, -1] = True # 最后一列 - 返回: - loss: 内部墙壁连续区域的平均惩罚损失 - """ - # 取出对应墙壁类别的概率图 [B, H, W] - wall_probs = pred[:, wall_class, :, :] + # 提取所有允许类别的概率和 [B, H, W] + allowed_probs = pred[:, allowed_classes, :, :].sum(dim=1) - # 排除最外圈,取内部区域 (H, W 均减去2) - interior = wall_probs[:, 1:-1, 1:-1] # [B, H-2, W-2] + # 获取外圈区域允许类别的概率 [B, N_pixels] + border_allowed = allowed_probs[:, border_mask] - # 构造一个 2×2 的卷积核,全为 1,用于检测局部连续墙壁的概率之和 - kernel = torch.ones((1, 1, 2, 2), device=pred.device) + # 计算不符合要求的概率(反向损失) + # 1 - 允许类别的概率 = 禁止类别的概率和 + border_violation = 1 - border_allowed - # 对内部区域进行卷积操作,计算每个 2×2 区域内的概率和 - # 需要将 interior 扩展一个通道维度 - conv_result = F.conv2d(interior.unsqueeze(1), kernel, stride=1, padding=0) - # conv_result 的形状为 [B, 1, H-3, W-3] + # 使用平滑的Huber损失替代直接均值 + loss = F.huber_loss( + border_violation, + torch.zeros_like(border_violation), + delta=0.1, + reduction='mean' + ) - # 对于每个 2×2 区域,如果概率和超过 threshold,则产生惩罚 - # 这里采用 ReLU 计算超出部分,确保损失为非负 - penalty = F.relu(conv_result - threshold) - - # 取平均作为损失值 - loss = penalty.mean() - return loss + return penalty_scale * loss -def entrance_loss(pred, stairs_class=10, arrow_class=11): - """ - 针对地图生成的额外约束损失: - - 保证最外圈不出现楼梯类型入口(数字10) - - 保证内部区域不出现箭头类型入口(数字11) - - 参数: - pred: 模型输出,形状 [B, C, H, W] - stairs_class: 楼梯入口对应的类别(数字10) - arrow_class: 箭头入口对应的类别(数字11) - - 返回: - loss: 针对入口出现的惩罚损失 - """ - # 先将 logits 转为概率分布 - B, C, H, W = pred.shape +def _create_distance_kernel(size): + """生成带距离权重的卷积核""" + kernel = torch.zeros(2*size-1, 2*size-1) + center = size-1 + for i in range(2*size-1): + for j in range(2*size-1): + dist = math.sqrt((i-center)**2 + (j-center)**2) + kernel[i,j] = 1 / (1 + dist) # 距离越近权重越高 + return kernel.view(1,1,2*size-1,2*size-1) - # 构造最外圈 mask:外圈为 True,其余为 False - outer_mask = torch.zeros((H, W), dtype=torch.bool, device=pred.device) - outer_mask[0, :] = True - outer_mask[-1, :] = True - outer_mask[:, 0] = True - outer_mask[:, -1] = True - - # 内部区域 mask - interior_mask = ~outer_mask # 取反 - - # 提取对应类别的概率图 - stairs_probs = pred[:, stairs_class, :, :] # 楼梯概率 [B, H, W] - arrow_probs = pred[:, arrow_class, :, :] # 箭头概率 [B, H, W] - - # 从最外圈提取楼梯概率;用 mask 索引时:张量[:, mask] 会将每个样本的外圈像素展平 - outer_stairs = stairs_probs[:, outer_mask] # [B, num_outer_pixels] - # 从内部区域提取箭头概率 - interior_arrow = arrow_probs[:, interior_mask] # [B, num_interior_pixels] - - # 损失设计:使得这些概率尽量接近 0,直接使用均值惩罚 - outer_loss = outer_stairs.mean() - interior_loss = interior_arrow.mean() - - total_loss = outer_loss + interior_loss - return total_loss - -def entrance_distance_and_presence_loss( - pred, - arrow_class=11, stairs_class=10, - arrow_min_threshold=0.5, stairs_min_threshold=0.5, - lambda_arrow_presence=1.0, lambda_stairs_presence=1.0 +def entrance_constraint_loss( + pred: torch.Tensor, + entrance_classes=[10, 11], # 假设10是楼梯,11是箭头 + min_distance=9, + presence_threshold=0.9, + lambda_presence=1.0, + lambda_spacing=0.5 ): """ - 入口损失同时考虑: - 1. 局部距离约束:防止同一类型入口过于靠近 - 2. 存在性约束:鼓励至少放置一个入口 - - 箭头入口要求局部 (9x9) 内最多只有一个入口; - 楼梯入口要求在一个窗口(地图尺寸一半)内只出现一个楼梯入口。 + 入口约束损失函数 参数: - pred: 模型输出, shape [B, C, H, W] - arrow_class: 箭头入口类别(默认 11) - stairs_class: 楼梯入口类别(默认 10) - arrow_min_threshold: 箭头入口全局最小平均概率要求(可根据任务调节) - stairs_min_threshold: 楼梯入口全局最小平均概率要求 - lambda_arrow_presence: 箭头入口存在性损失权重 - lambda_stairs_presence: 楼梯入口存在性损失权重 + pred: 模型输出的概率分布 [B, C, H, W] + entrance_classes: 入口类别列表 + min_distance: 最小间隔距离(对应卷积核尺寸) + presence_threshold: 存在性概率阈值 + lambda_presence: 存在性损失权重 + lambda_spacing: 间距约束权重 + 返回: - total_loss: 综合入口距离与存在性损失 + total_loss: 综合损失值 """ - # 将 logits 转换为概率分布 B, C, H, W = pred.shape + entrance_probs = pred[:, entrance_classes].sum(dim=1) - # 提取箭头和楼梯的概率图 - arrow_probs = pred[:, arrow_class, :, :] # [B, H, W] - stairs_probs = pred[:, stairs_class, :, :] # [B, H, W] + ########################### + # 改进的存在性约束 + ########################### + # 计算存在性损失:鼓励至少有一个高置信度入口 + max_per_sample = entrance_probs.view(B, -1).max(dim=1)[0] + presence_loss = F.relu(presence_threshold - max_per_sample).mean() - #### 局部距离约束 #### - # 箭头:构造 9x9 卷积核,半径 4 - kernel_arrow = torch.ones((1, 1, 9, 9), device=pred.device) - local_arrow_sum = F.conv2d(arrow_probs.unsqueeze(1), kernel_arrow, padding=4) - # 减去自身概率,计算多余的局部累积 - arrow_excess = local_arrow_sum - arrow_probs.unsqueeze(1) - arrow_distance_loss = F.relu(arrow_excess).mean() + ########################### + # 改进的间距约束 + ########################### + # 生成空间权重掩码(中心衰减) + y, x = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij') + center_weight = 1 - torch.sqrt(((x-W//2)/W*2)**2 + ((y-H//2)/H*2)**2) + center_weight = center_weight.clamp(0,1).to(pred.device) # [H,W] - # 楼梯:使用窗口大小为 (W//2, H//2) - kernel_size_stairs = (9, 9) - kernel_stairs = torch.ones((1, 1, kernel_size_stairs[0], kernel_size_stairs[1]), device=pred.device) - pad_stairs = ((kernel_size_stairs[0] - 1) // 2, (kernel_size_stairs[1] - 1) // 2) - local_stairs_sum = F.conv2d(stairs_probs.unsqueeze(1), kernel_stairs, padding=pad_stairs) - stairs_excess = local_stairs_sum - stairs_probs.unsqueeze(1) - stairs_distance_loss = F.relu(stairs_excess).mean() - - #### 存在性约束 #### - # 计算每个样本中箭头的最大概率 - global_arrow_max = arrow_probs.view(B, -1).max(dim=1)[0] # [B] - global_stairs_max = stairs_probs.view(B, -1).max(dim=1)[0] # [B] + # 概率密度感知的间距计算 + kernel = _create_distance_kernel(min_distance).to(pred.device) # 自定义函数生成权重核 + density_map = F.conv2d(entrance_probs.unsqueeze(1), kernel, padding=min_distance-1) - # 取 batch 平均(或者你可以对每个样本分别计算损失再求平均) - global_arrow_max = global_arrow_max.mean() - global_stairs_max = global_stairs_max.mean() - - # 如果全局均值低于预期阈值,则施加额外惩罚 - arrow_presence_loss = F.relu(arrow_min_threshold - global_arrow_max) - stairs_presence_loss = F.relu(stairs_min_threshold - global_stairs_max) + # 平滑惩罚函数:S形曲线 + spacing_loss = torch.sigmoid(10*(density_map - 0.5)).mean() # 密度>0.5时快速上升 - ap_weighted = lambda_arrow_presence * arrow_presence_loss - sp_weighted = lambda_stairs_presence * stairs_presence_loss + # print(entrance_probs) + print(presence_loss.item(), (density_map).mean().item(), center_weight.mean().item()) - # 总入口损失:局部距离约束 + 存在性约束(加权) - total_loss = arrow_distance_loss + stairs_distance_loss \ - + min(ap_weighted, sp_weighted) + ########################### + # 区域加权综合损失 + ########################### + total_loss = ( + lambda_presence * presence_loss + + lambda_spacing * (spacing_loss * center_weight).mean() + ) return total_loss -def monster_consecutive_loss(pred, monster_classes=[7,8,9], threshold=2.9): +def adaptive_count_loss( + pred_probs: torch.Tensor, + target_map: torch.Tensor, + class_list: list = list(range(32)), + margin_ratio: float = 0.2, + zero_margin_scale: float = 0.3, + eps: float = 1e-6 +) -> torch.Tensor: """ - 检查横向和纵向是否存在连续超过三个的怪物(类别 7,8,9)。 + 自适应图块数量约束损失函数 参数: - pred: 模型输出,形状 [B, C, H, W] - monster_classes: 待检测的怪物类别列表 - threshold: 滑动窗口内概率和的阈值,若超过则施加惩罚 - (对于连续三个像素,如果每个像素概率接近 1,则窗口和接近 3) + pred_probs: 预测概率分布 [B, C, H, W] + target_map: 真实地图 [B, C, H, W] + class_list: 需要约束的类别列表 + margin_ratio: 允许的相对误差范围(如0.2表示±20%) + zero_margin_scale: 参考数量为0时的允许余量系数(余量=scale*sqrt(H*W)) + eps: 数值稳定性常数 返回: - loss: 惩罚损失(数值越高表示连续怪物区域越严重) - """ - # 将 logits 转换为概率分布 - B, C, H, W = pred.shape - - # 得到怪物整体概率图:将类别 7,8,9 的概率相加 - monster_probs = pred[:, monster_classes, :].sum(dim=1) # [B, H, W] - - # 注意:monster_probs 越高说明该像素更有可能是怪物 - - # --- 横向检测 --- - # 构造一个 (1,3) 的卷积核,全 1 - kernel_horiz = torch.ones((1, 1, 1, 3), device=pred.device) - # 对 monster_probs 加一个 channel 维度,使形状为 [B, 1, H, W] - conv_horiz = F.conv2d(monster_probs.unsqueeze(1), kernel_horiz, padding=(0,1)) - # conv_horiz 的每个值表示相邻三个像素的怪物概率和 - - # --- 纵向检测 --- - # 构造一个 (3,1) 的卷积核,全 1 - kernel_vert = torch.ones((1, 1, 3, 1), device=pred.device) - conv_vert = F.conv2d(monster_probs.unsqueeze(1), kernel_vert, padding=(1,0)) - # conv_vert 的每个值表示垂直连续三个像素的怪物概率和 - - # 对两个方向的窗口,如果概率和超过阈值,则计算超出部分的惩罚 - penalty_horiz = F.relu(conv_horiz - threshold) - penalty_vert = F.relu(conv_vert - threshold) - - # 将两个方向的惩罚损失取平均(或者直接相加) - loss = penalty_horiz.mean() + penalty_vert.mean() - return loss - -def illegal_block_loss(pred, used_classes=12, mode='mean'): - """ - 对未使用类别(例如 12 ~ 31)的预测概率施加惩罚, - 鼓励模型输出仅集中在 0 ~ 11 上。 - - 参数: - pred: 模型输出,形状 [B, num_classes, H, W] - used_classes: 已经使用的类别数(例如 12 表示只使用 0-11) - mode: 'mean' 使用平均概率,或 'mse' 使用均方误差 - - 返回: - penalty: 标量惩罚损失 - """ - B, C, H, W = pred.shape - # 选取非法类别的概率(注意:这一步会得到非法图块在每个像素上的概率) - illegal_probs = pred[:, range(used_classes, 32), :, :] # [B, len(illegal_classes), H, W] - - # 我们可以将非法图块的概率在类别维度上求和,得到每个像素的非法激活值 - illegal_activation = illegal_probs.sum(dim=1) # [B, H, W] - - # 接下来我们计算整个图上非法激活的“数量” - # 例如,可以直接对整个 batch 内非法激活求和 - total_illegal = illegal_activation.sum() / B # 标量 - - # 计算损失值:使用负指数函数。注意如果非法激活很小,总损失接近 exp(0)=1 - loss = torch.sqrt(total_illegal).mean() - return loss - -def integrated_count_loss(probs, target, class_list=[0,1,2,3,4,5,6,7,8,9], tolerance=0.5): - """ - 对每个类别分别计算数量匹配损失,再取平均。 - - 参数: - probs: 模型输出的概率,形状 [B, num_classes, H, W] - target: 真实标签,形状 [B, H, W],类别取值在 0 ~ 使用范围-1 内 - class_list: 需要计算的类别列表 - tolerance: 每个类别允许的相对误差(例如 0.15 表示 15%) - - 返回: - loss: 对每个类别数量匹配损失取平均后的标量 + loss: 标量损失值 """ + B, C, H, W = pred_probs.shape + device = pred_probs.device total_loss = 0.0 - count = 0 - B, C, H, W = probs.shape + valid_classes = 0 + + # 预计算地图面积用于余量计算 + map_area = math.sqrt(H * W) for cls in class_list: - # 预测数量:对于当前类别,所有像素的预测概率和 - pred_count = probs[:, cls, :, :].sum() - # 真实数量:统计 target 中属于当前类别的像素数量 - true_count = (target == cls).float().sum() + # 预测数量(概率和) + pred_count = pred_probs[:, cls].sum(dim=(1,2)) # [B] + # 真实数量 + true_count = target_map[:, cls].sum(dim=(1,2)) # [B] - if true_count == 0: - # 参考地图中不包含该类别,允许最多出现 (sqrt(地图尺寸) / 2) 个单位的概率输出 - cls_loss = F.relu(pred_count - math.sqrt(H * W) / 2) - else: - # 计算相对误差 - rel_error = torch.abs(pred_count - true_count) / (true_count) - cls_loss = F.relu(rel_error - tolerance) + # 动态容差计算 + with torch.no_grad(): + # 当真实数量为0时的允许上限 + zero_mask = (true_count == 0) + dynamic_margin = torch.where( + zero_mask, + zero_margin_scale * map_area, # 允许存在少量 + margin_ratio * true_count # 相对误差范围 + ) - total_loss += cls_loss - count += 1 + # 误差计算(考虑数值稳定性) + safe_true = true_count + eps * zero_mask # 零真实值时添加微小量 + abs_error = torch.abs(pred_count - true_count) + rel_error = abs_error / safe_true - # 求平均每个类别的损失 - avg_loss = total_loss / count - return avg_loss + # 双阶段损失函数 + # 阶段一:误差在容差范围内时使用二次函数(强梯度) + # 阶段二:超出容差时转为线性(稳定训练) + loss_per_class = torch.where( + abs_error <= dynamic_margin, + (rel_error ** 2) * 0.5, # 区间内强梯度 + rel_error - (0.5 * margin_ratio) # 区间外稳定梯度 + ) + + # 零真实值特殊处理:仅惩罚超出余量部分 + loss_per_class = torch.where( + zero_mask, + F.relu(abs_error - dynamic_margin) / map_area, # 归一化处理 + loss_per_class + ) + + total_loss += loss_per_class.mean() + valid_classes += 1 + + return total_loss / valid_classes # 类别平均 + +def illegal_tile_loss( + pred_probs: torch.Tensor, + legal_classes: int = 13, + temperature: float = 0.1, + eps: float = 1e-8 +) -> torch.Tensor: + """ + 非法图块惩罚损失函数 + + 参数: + pred_probs: 模型输出的概率分布 [B, C, H, W] + legal_classes: 合法图块数量(0-based,默认0-12为合法) + temperature: 概率锐化温度系数(0.1-1.0) + eps: 数值稳定性保护 + + 返回: + loss: 标量损失值 + """ + B, C, H, W = pred_probs.shape + + # 提取非法图块概率(类别13及之后) + illegal_probs = pred_probs[:, legal_classes:, :, :] # [B, C_illegal, H, W] + + # 概率锐化(增强高概率区域的惩罚) + sharpened_probs = torch.exp(torch.log(illegal_probs + eps) / temperature) + sharpened_probs = sharpened_probs / (sharpened_probs.sum(dim=1, keepdim=True) + eps) + + # 空间敏感权重(关注高置信度非法区域) + with torch.no_grad(): + # 计算每个像素的非法概率置信度 + confidence = illegal_probs.max(dim=1)[0] # [B, H, W] + # 生成注意力权重(高置信度区域权重加倍) + spatial_weights = 1 + torch.sigmoid(10*(confidence - 0.5)) + + # 逐像素计算非法概率损失 + per_pixel_loss = torch.log(1 + illegal_probs.sum(dim=1)) # [B, H, W] + + # 加权空间损失 + weighted_loss = (per_pixel_loss * spatial_weights).mean() + + # 类别平衡因子(抑制高频非法类别) + class_balance = 1 + torch.var(illegal_probs.mean(dim=(0,2,3))) # [C_illegal] + + return weighted_loss * class_balance.mean() + +def entrance_spatial_constraint( + pred_probs: torch.Tensor, + arrow_class: int = 11, + stair_class: int = 10, + border_width: int = 1, + lambda_arrow: float = 1.0, + lambda_stair: float = 1.0 +) -> torch.Tensor: + """ + 入口空间约束损失函数 + + 参数: + pred_probs: 模型输出的概率分布 [B, C, H, W] + arrow_class: 箭头入口类别索引 + stair_class: 楼梯入口类别索引 + border_width: 边缘区域宽度(默认1表示最外圈) + lambda_arrow: 箭头约束权重 + lambda_stair: 楼梯约束权重 + + 返回: + loss: 标量损失值 + """ + B, C, H, W = pred_probs.shape + + ########################################## + # 1. 区域掩码生成 + ########################################## + # 生成边缘区域掩码 [H, W] + edge_mask = torch.zeros((H, W), dtype=torch.bool, device=pred_probs.device) + # 上下边缘 + edge_mask[:border_width, :] = True + edge_mask[-border_width:, :] = True + # 左右边缘(排除已标记的角落) + edge_mask[:, :border_width] = True + edge_mask[:, -border_width:] = True + + # 生成中间区域掩码 [H, W] + center_mask = ~edge_mask + + ########################################## + # 2. 边缘区域约束(只能出现箭头) + ########################################## + # 提取边缘区域的箭头概率 [B, N_edge_pixels] + edge_arrow_probs = pred_probs[:, arrow_class][:, edge_mask] + + # 边缘应最大化箭头概率(最小化1 - arrow_prob) + edge_arrow_loss = (1 - edge_arrow_probs).mean() + + # 抑制边缘出现楼梯的概率 [B, N_edge_pixels] + edge_stair_probs = pred_probs[:, stair_class][:, edge_mask] + edge_stair_penalty = F.relu(edge_stair_probs - 0.1).mean() # 允许10%以下 + + ########################################## + # 3. 中间区域约束(只能出现楼梯) + ########################################## + # 提取中间区域的楼梯概率 [B, N_center_pixels] + center_stair_probs = pred_probs[:, stair_class][:, center_mask] + + # 中间应最大化楼梯概率(最小化1 - stair_prob) + center_stair_loss = (1 - center_stair_probs).mean() + + # 抑制中间出现箭头的概率 [B, N_center_pixels] + center_arrow_probs = pred_probs[:, arrow_class][:, center_mask] + center_arrow_penalty = F.relu(center_arrow_probs - 0.1).mean() # 允许10%以下 + + ########################################## + # 4. 综合损失 + ########################################## + total_loss = ( + lambda_arrow * (edge_arrow_loss + edge_stair_penalty) + + lambda_stair * (center_stair_loss + center_arrow_penalty) + ) + + return total_loss class GinkaLoss(nn.Module): - def __init__(self, minamo: MinamoModel, weight=[0.35, 0.1, 0.1, 0.1, 0.1, 0.05, 0.1, 0.1]): + def __init__(self, minamo: MinamoModel, weight=[0.5, 0.15, 0.15, 0.1, 0.1]): """Ginka Model 损失函数部分 Args: weight (list, optional): 每一个损失函数的权重,从第 0 项开始,依次是: - 1. 拓扑图损失 + 1. Minamo 相似度损失 2. 外圈墙壁损失 - 3. 内层 2*2 墙壁损失 - 4. 要求外层只能有箭头,内层只能有楼梯的损失 - 5. 入口间距及存在性损失 - 6. 连续怪物损失 - 7. 非法图块损失 - 8. 怪物、道具、门数量损失 + 3. 入口间距及存在性损失 + 4. 怪物、道具、门数量损失 + 5. 非法图块损失 """ super().__init__() self.weight = weight @@ -304,41 +325,37 @@ class GinkaLoss(nn.Module): def forward(self, pred, target, target_vision_feat, target_topo_feat): # 地图结构损失 - border_loss = wall_border_loss(pred) - wall_loss = internal_wall_loss(pred) - entry_loss = entrance_loss(pred) - entry_dis_loss = entrance_distance_and_presence_loss(pred, ) - enemy_loss = monster_consecutive_loss(pred) - valid_block_loss = illegal_block_loss(pred, used_classes=12, mode="mean") - count_loss = integrated_count_loss(pred, target) + border_loss = outer_border_constraint_loss(pred) + entrance_loss = entrance_constraint_loss(pred) * 0.5 + entrance_spatial_constraint(pred) * 0.5 + count_loss = adaptive_count_loss(pred, target) + illegal_loss = illegal_tile_loss(pred) # 使用 Minamo Model 计算相似度 - graph = convert_soft_map_to_graph(pred) + graph = batch_convert_soft_map_to_graph(pred) pred_vision_feat, pred_topo_feat = self.minamo(pred, graph) vision_sim = F.cosine_similarity(pred_vision_feat, target_vision_feat, dim=-1) topo_sim = F.cosine_similarity(pred_topo_feat, target_topo_feat, dim=-1) minamo_sim = 0.3 * vision_sim + 0.7 * topo_sim - minamo_loss = torch.exp(-1 * (minamo_sim - 0.8)).mean() + minamo_loss = (1.0 - minamo_sim).mean() print( minamo_loss.item(), border_loss.item(), - wall_loss.item(), - entry_loss.item(), - entry_dis_loss.item(), - enemy_loss.item(), - valid_block_loss.item(), - count_loss.item() + entrance_loss.item(), + count_loss.item(), + illegal_loss.item() ) - return ( - minamo_loss * self.weight[0] + - border_loss * self.weight[1] + - wall_loss * self.weight[2] + - entry_loss * self.weight[3] + - entry_dis_loss * self.weight[4] + - enemy_loss * self.weight[5] + - valid_block_loss * self.weight[6] + - count_loss * self.weight[7] - ) \ No newline at end of file + losses = [ + minamo_loss * self.weight[0], + border_loss * self.weight[1] * 0.1, + entrance_loss * self.weight[2], + count_loss * self.weight[3], + illegal_loss * self.weight[4] + ] + + # 梯度归一化 + scaled_losses = [loss / (loss.detach() + 1e-6) for loss in losses] + total_loss = sum(scaled_losses) + return total_loss \ No newline at end of file diff --git a/ginka/model/model.py b/ginka/model/model.py index 07ddf53..dac8576 100644 --- a/ginka/model/model.py +++ b/ginka/model/model.py @@ -4,7 +4,7 @@ import torch.nn.functional as F from .unet import GinkaUNet class GinkaModel(nn.Module): - def __init__(self, feat_dim=256, base_ch=64, num_classes=32): + def __init__(self, feat_dim=256, base_ch=128, num_classes=32): """Ginka Model 模型定义部分 """ super().__init__() @@ -25,5 +25,5 @@ class GinkaModel(nn.Module): x = x.view(-1, self.base_ch, 32, 32) x = self.unet(x) x = F.interpolate(x, (13, 13), mode='bilinear', align_corners=False) - return F.softmax(x) + return F.softmax(x, dim=1) \ No newline at end of file diff --git a/ginka/train.py b/ginka/train.py index 3fcce09..70047a7 100644 --- a/ginka/train.py +++ b/ginka/train.py @@ -46,13 +46,13 @@ def train(): ) # 设定优化器与调度器 - optimizer = optim.AdamW(model.parameters(), lr=3e-4) + optimizer = optim.AdamW(model.parameters(), lr=1e-3) scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6) - criterion = GinkaLoss(minamo, weight=[1, 0, 0, 0, 0, 0, 0, 0]) + criterion = GinkaLoss(minamo) - model.register_full_backward_hook(grad_hook) + # model.register_full_backward_hook(grad_hook) # converter.register_full_backward_hook(grad_hook) - criterion.register_full_backward_hook(grad_hook) + # criterion.register_full_backward_hook(grad_hook) # 开始训练 for epoch in tqdm(range(epochs)): diff --git a/shared/graph.py b/shared/graph.py index 9f4fb31..17c15c6 100644 --- a/shared/graph.py +++ b/shared/graph.py @@ -1,5 +1,5 @@ import torch -from torch_geometric.data import Data +from torch_geometric.data import Data, Batch def convert_soft_map_to_graph(map_probs: torch.Tensor): """ @@ -31,3 +31,17 @@ def convert_soft_map_to_graph(map_probs: torch.Tensor): map_probs[:, edge_index[1] // W, edge_index[1] % W]) / 2 return Data(x=node_features, edge_index=edge_index, edge_attr=soft_edge_weight) + +def batch_convert_soft_map_to_graph(batch_map_probs): + """ + 处理 batch 维度,将 [B, C, H, W] 转换为批量图结构 Batch + """ + B, C, H, W = batch_map_probs.shape # 获取 batch 维度 + batch_graphs = [] + + for i in range(B): + graph = convert_soft_map_to_graph(batch_map_probs[i]) # 处理单个样本 + batch_graphs.append(graph) + + # 合并所有图为批量 Batch + return Batch.from_data_list(batch_graphs)