diff --git a/ginka/model/loss.py b/ginka/model/loss.py index a279627..a083750 100644 --- a/ginka/model/loss.py +++ b/ginka/model/loss.py @@ -5,7 +5,22 @@ import torch.nn.functional as F from minamo.model.model import MinamoModel from shared.graph import batch_convert_soft_map_to_graph -def outer_border_constraint_loss(pred: torch.Tensor, allowed_classes=[1, 11], penalty_scale=1.0): +CLASS_NUM = 32 +ILLEGAL_MAX_NUM = 12 + +def get_not_allowed(classes: list[int], include_illegal=False): + res = list() + for num in range(0, CLASS_NUM): + if not num in classes: + if num > ILLEGAL_MAX_NUM: + if include_illegal: + res.append(num) + else: + res.append(num) + + return res + +def outer_border_constraint_loss(pred: torch.Tensor, allowed_classes=[1, 11]): """ 强制地图最外圈像素必须为指定类别(墙或箭头) @@ -26,40 +41,58 @@ def outer_border_constraint_loss(pred: torch.Tensor, allowed_classes=[1, 11], pe border_mask[:, 0] = True # 第一列 border_mask[:, -1] = True # 最后一列 - # 提取所有允许类别的概率和 [B, H, W] - allowed_probs = pred[:, allowed_classes, :, :].sum(dim=1) + # 提取所有允许和不允许类别的概率和 [B, H, W] + unallowed_probs = pred[:, get_not_allowed(allowed_classes, include_illegal=True), :, :].sum(dim=1) # 获取外圈区域允许类别的概率 [B, N_pixels] - border_allowed = allowed_probs[:, border_mask] + border_unallowed = unallowed_probs[:, border_mask] - # 计算不符合要求的概率(反向损失) - # 1 - 允许类别的概率 = 禁止类别的概率和 - border_violation = 1 - border_allowed + target = torch.zeros_like(border_unallowed) + loss_unallowed = F.mse_loss(border_unallowed, target) - # 使用平滑的Huber损失替代直接均值 - loss = F.huber_loss( - border_violation, - torch.zeros_like(border_violation), - delta=0.1, - reduction='mean' - ) + return loss_unallowed + +def inner_constraint_loss(pred: torch.Tensor, allowed=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12]): + """限定内部允许出现的图块种类 + + Args: + pred (torch.Tensor): 模型输出的概率分布 [B, C, H, W] + unallowed (list, optional): 在地图中部(处最外圈)允许出现的图块种类. Defaults to [11]. + """ + B, C, H, W = pred.shape - return penalty_scale * loss + # 创建内部 mask [H, W] + mask = torch.ones((H, W), dtype=torch.bool, device=pred.device) + mask[0, :] = False # 第一行 + mask[-1, :] = False # 最后一行 + mask[:, 0] = False # 第一列 + mask[:, -1] = False # 最后一列 + + # 提取所有允许和不允许类别的概率和 [B, H, W] + unallowed_probs = pred[:, get_not_allowed(allowed, include_illegal=True), :, :].sum(dim=1) + + # 获取外圈区域允许类别的概率 [B, N_pixels] + inner_unallowed = unallowed_probs[:, mask] + + target = torch.zeros_like(inner_unallowed) + loss_unallowed = F.mse_loss(inner_unallowed, target) + + return loss_unallowed def _create_distance_kernel(size): """生成一个环状衰减核""" y, x = torch.meshgrid(torch.arange(size), torch.arange(size), indexing='ij') center = size // 2 dist = torch.sqrt((x - center)**2 + (y - center)**2) - kernel = torch.exp(-dist / (size / 2)) # 高斯衰减 + kernel = 1 / (dist + 1) kernel /= kernel.sum() # 归一化 - return kernel.unsqueeze(0).unsqueeze(0) # [1,1,H,W] + return kernel.unsqueeze(0).unsqueeze(0), 1 / kernel.sum() # [1,1,H,W] def entrance_constraint_loss( pred: torch.Tensor, entrance_classes=[10, 11], # 假设10是楼梯,11是箭头 min_distance=9, - presence_threshold=0.9, + presence_threshold=0.8, lambda_presence=1.0, lambda_spacing=0.5 ): @@ -78,29 +111,23 @@ def entrance_constraint_loss( total_loss: 综合损失值 """ B, C, H, W = pred.shape - entrance_probs = pred[:, entrance_classes].sum(dim=1) + entrance_probs = pred[:, entrance_classes, :, :].sum(dim=1) # [B, H, W] - ########################### - # 改进的存在性约束 - ########################### # 计算存在性损失:鼓励至少有一个高置信度入口 - max_per_sample = entrance_probs.view(B, -1).max(dim=1)[0] + max_per_sample = entrance_probs.view(B, -1).max(dim=1)[0] # [B, H*W] -> [B, 1] presence_loss = F.relu(presence_threshold - max_per_sample).mean() - ########################### - # 改进的间距约束 - ########################### # 生成空间权重掩码(中心衰减) y, x = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij') center_weight = 1 - torch.sqrt(((x-W//2)/W*2)**2 + ((y-H//2)/H*2)**2) center_weight = center_weight.clamp(0,1).to(pred.device) # [H,W] # 概率密度感知的间距计算 - kernel = _create_distance_kernel(min_distance).to(pred.device) # 自定义函数生成权重核 + kernel, cw = _create_distance_kernel(min_distance) # 自定义函数生成权重核 + kernel = kernel.to(pred.device) density_map = F.conv2d(entrance_probs.unsqueeze(1), kernel, padding=min_distance-1) - # 平滑惩罚函数:S形曲线 - spacing_loss = torch.sigmoid(10*(density_map - 0.5)).mean() # 密度>0.5时快速上升 + spacing_loss = density_map.mean() ########################### # 区域加权综合损失 @@ -115,16 +142,16 @@ def adaptive_count_loss( pred_probs: torch.Tensor, target_map: torch.Tensor, class_list: list = list(range(32)), - margin_ratio: float = 0.2, - zero_margin_scale: float = 0.2, - lambda_entropy: float = 0.05, - lambda_local: float = 0.1, - grid_size: int = 8, + margin_ratio: float = 0.1, # 降低margin比例以更严格 + zero_margin_scale: float = 0.1, # 减少零类别的margin + lambda_entropy: float = 0.2, # 增大熵约束权重 + lambda_local: float = 0.2, + lambda_max: float = 0, # 新增最大概率约束 + grid_size: int = 4, # 减小局部网格尺寸 eps: float = 1e-3 ) -> torch.Tensor: """ - 改进版自适应图块数量约束损失,包含局部匹配和熵约束 - + 改进版自适应图块数量约束损失,增强局部匹配和概率确定性 """ B, C, H, W = pred_probs.shape device = pred_probs.device @@ -134,57 +161,66 @@ def adaptive_count_loss( # 预计算地图面积 map_area = math.sqrt(H * W) - # 计算最小非零类别概率 - min_nonzero_prob = pred_probs[:, class_list].max(dim=1).values.mean() # 获取预测中的最小非零概率 - dynamic_zero_margin = zero_margin_scale * min_nonzero_prob * map_area # 让零类别不被填充 + # 动态调整零类别的margin:基于预测中最小的非零概率 + min_nonzero_prob = pred_probs[:, class_list].max(dim=1).values.mean() + dynamic_zero_margin = zero_margin_scale * min_nonzero_prob * map_area + # 计算每个类别的数量损失 for cls in class_list: - pred_count = pred_probs[:, cls].sum(dim=(1,2)) # 预测类别数量 - true_count = target_map[:, cls].sum(dim=(1,2)) # 真实类别数量 + pred_count = pred_probs[:, cls].sum(dim=(1,2)) # 预测数量 + true_count = target_map[:, cls].sum(dim=(1,2)) # 真实数量 zero_mask = (true_count == 0) dynamic_margin = torch.where( zero_mask, - dynamic_zero_margin, - margin_ratio * true_count + dynamic_zero_margin, + margin_ratio * true_count ) safe_true = true_count + eps * zero_mask abs_error = torch.abs(pred_count - true_count) rel_error = abs_error / safe_true + # 调整损失函数形状,远离目标时惩罚更大 loss_per_class = torch.where( abs_error <= dynamic_margin, - (rel_error ** 2) * 0.8 + 0.2 * rel_error, - rel_error - (0.5 * margin_ratio) + rel_error ** 2, # 近目标时二次损失 + (rel_error - 0.5 * margin_ratio) ** 2 # 远目标时二次增长 ) + # 零类别使用更严格的绝对误差惩罚 loss_per_class = torch.where( zero_mask, - F.relu(abs_error - dynamic_margin) / map_area, + F.relu(abs_error - dynamic_zero_margin) ** 2 / map_area, loss_per_class ) total_loss += loss_per_class.mean() valid_classes += 1 - # 平均类别损失 - total_loss /= valid_classes + total_loss /= valid_classes # 平均类别损失 - # 加入负熵约束,防止类别均匀化 + # 改进的熵约束:每个像素的熵 def entropy_loss(pred_probs): - avg_probs = pred_probs.mean(dim=(2, 3)) - entropy = -torch.sum(avg_probs * torch.log(avg_probs + 1e-6), dim=1) - return entropy.mean() - + entropy_per_pixel = -torch.sum(pred_probs * torch.log(pred_probs + 1e-6), dim=1) + return entropy_per_pixel.mean() # 所有像素的平均熵 + total_loss += lambda_entropy * entropy_loss(pred_probs) - # 加入局部类别匹配 - def local_count_loss(pred_probs, target_probs, grid_size=8): - pred_local = F.avg_pool2d(pred_probs, kernel_size=grid_size, stride=grid_size) - target_local = F.avg_pool2d(target_probs, kernel_size=grid_size, stride=grid_size) - return F.mse_loss(pred_local, target_local) + # 新增最大概率约束:鼓励每个位置概率尖锐化 + max_probs = pred_probs.max(dim=1)[0] # 每个位置的最大概率 + max_loss = (1 - max_probs).mean() # 鼓励接近1 + total_loss += lambda_max * max_loss + # 改进局部损失:约束局部区域内的数量 + def local_count_loss(pred_probs, target_probs, grid_size): + grid_area = grid_size ** 2 + # 计算每个grid内的预测数量 + pred_counts = F.avg_pool2d(pred_probs, grid_size, stride=grid_size) * grid_area + target_counts = F.avg_pool2d(target_probs, grid_size, stride=grid_size) * grid_area + # 使用L1损失更鲁棒 + return F.l1_loss(pred_counts, target_counts) + total_loss += lambda_local * local_count_loss(pred_probs, target_map, grid_size) return total_loss @@ -300,16 +336,15 @@ def entrance_spatial_constraint( return total_loss class GinkaLoss(nn.Module): - def __init__(self, minamo: MinamoModel, weight=[0.5, 0.1, 0.1, 0.2, 0.1]): + def __init__(self, minamo: MinamoModel, weight=[0.5, 0.2, 0.1, 0.2]): """Ginka Model 损失函数部分 Args: weight (list, optional): 每一个损失函数的权重,从第 0 项开始,依次是: 1. Minamo 相似度损失 - 2. 外圈墙壁损失 + 2. 图块种类损失,要求内部不出现箭头,外圈只出现箭头和墙壁,不允许出现非法图块 3. 入口间距及存在性损失 4. 怪物、道具、门数量损失 - 5. 非法图块损失 """ super().__init__() self.weight = weight @@ -317,10 +352,9 @@ class GinkaLoss(nn.Module): def forward(self, pred, target, target_vision_feat, target_topo_feat): # 地图结构损失 - border_loss = outer_border_constraint_loss(pred) - entrance_loss = entrance_constraint_loss(pred) * 0.5 + entrance_spatial_constraint(pred) * 0.5 + class_loss = outer_border_constraint_loss(pred) + inner_constraint_loss(pred) + entrance_loss = entrance_constraint_loss(pred) + entrance_spatial_constraint(pred) count_loss = adaptive_count_loss(pred, target) - illegal_loss = illegal_tile_loss(pred) # 使用 Minamo Model 计算相似度 graph = batch_convert_soft_map_to_graph(pred) @@ -328,23 +362,21 @@ class GinkaLoss(nn.Module): vision_sim = F.cosine_similarity(pred_vision_feat, target_vision_feat, dim=-1) topo_sim = F.cosine_similarity(pred_topo_feat, target_topo_feat, dim=-1) - minamo_sim = 0.3 * vision_sim + 0.7 * topo_sim + minamo_sim = 0.2 * vision_sim + 0.8 * topo_sim minamo_loss = (1.0 - minamo_sim).mean() print( minamo_loss.item(), - border_loss.item(), + class_loss.item(), entrance_loss.item(), - count_loss.item(), - illegal_loss.item() + count_loss.item() ) losses = [ minamo_loss * self.weight[0], - border_loss * self.weight[1], + class_loss * self.weight[1], entrance_loss * self.weight[2], - count_loss * self.weight[3], - illegal_loss * self.weight[4] + count_loss * self.weight[3] ] # 梯度归一化 diff --git a/ginka/model/model.py b/ginka/model/model.py index 177966c..b534ff3 100644 --- a/ginka/model/model.py +++ b/ginka/model/model.py @@ -8,12 +8,12 @@ def print_memory(tag=""): print(f"{tag} | 当前显存: {torch.cuda.memory_allocated() / 1024**2:.2f} MB, 最大显存: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB") class GinkaModel(nn.Module): - def __init__(self, feat_dim=1024, base_ch=64, num_classes=32): + def __init__(self, feat_dim=1024, base_ch=64, out_ch=32): """Ginka Model 模型定义部分 """ super().__init__() - self.unet = GinkaUNet(1, base_ch, num_classes, feat_dim) - self.output = GinkaOutput(num_classes, (13, 13)) + self.unet = GinkaUNet(1, base_ch, out_ch, feat_dim) + self.output = GinkaOutput(out_ch, out_ch, (13, 13)) def forward(self, x, feat): """ diff --git a/ginka/model/output.py b/ginka/model/output.py index 89989ab..7e34208 100644 --- a/ginka/model/output.py +++ b/ginka/model/output.py @@ -2,9 +2,12 @@ import torch import torch.nn as nn class GinkaOutput(nn.Module): - def __init__(self, num_classes=32, out_size=(13, 13)): + def __init__(self, out_ch=32, base_ch=64, out_size=(13, 13)): super().__init__() - self.pool = nn.AdaptiveAvgPool2d(out_size) + self.conv_down = nn.Sequential( + nn.AdaptiveMaxPool2d(out_size), + nn.Conv2d(base_ch, out_ch, 1) + ) def forward(self, x): - return self.pool(x) + return self.conv_down(x) diff --git a/ginka/train.py b/ginka/train.py index 9b0dbf7..34ef1c5 100644 --- a/ginka/train.py +++ b/ginka/train.py @@ -49,7 +49,7 @@ def train(): if args.resume: data = torch.load(args.from_state, map_location=device) - model.load_state_dict(data["model_state"]) + model.load_state_dict(data["model_state"], strict=False) if args.load_optim: optimizer.load_state_dict(data["optimizer_state"]) print("Train from loaded state.") diff --git a/minamo/model/loss.py b/minamo/model/loss.py index faefebf..6fb1719 100644 --- a/minamo/model/loss.py +++ b/minamo/model/loss.py @@ -1,7 +1,7 @@ import torch.nn as nn class MinamoLoss(nn.Module): - def __init__(self, vision_weight=0.4, topo_weight=0.6): + def __init__(self, vision_weight=0.2, topo_weight=0.8): super().__init__() self.vision_weight = vision_weight self.topo_weight = topo_weight