refactor: 重构损失函数部分

This commit is contained in:
unanmed 2025-03-20 13:41:54 +08:00
parent d6b2b13ac8
commit 171dcf60f1
4 changed files with 313 additions and 282 deletions

View File

@ -3,300 +3,321 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from minamo.model.model import MinamoModel
from shared.graph import convert_soft_map_to_graph
from shared.graph import batch_convert_soft_map_to_graph
def wall_border_loss(pred: torch.Tensor, allow_border=[1, 11]):
"""地图最外层是否为墙"""
# 计算 softmax 概率
def outer_border_constraint_loss(pred: torch.Tensor, allowed_classes=[1, 11], penalty_scale=1.0):
"""
强制地图最外圈像素必须为指定类别墙或箭头
参数:
pred: 模型输出的概率分布形状 [B, C, H, W]
allowed_classes: 允许出现在外圈的类别列表默认[1,11]
penalty_scale: 惩罚强度系数
返回:
loss: 标量损失值
"""
B, C, H, W = pred.shape
# 构造一个 [H, W] 的布尔 mask选取最外圈的像素
# 创建外圈mask [H, W]
border_mask = torch.zeros((H, W), dtype=torch.bool, device=pred.device)
border_mask[0, :] = True
border_mask[-1, :] = True
border_mask[:, 0] = True
border_mask[:, -1] = True
# 对允许的类别求概率和(即该像素为允许类别的总概率)
allowed_prob = pred[:, allow_border, :, :].sum(dim=1) # [B, H, W]
# 只计算边界区域的损失:对于边界上的每个像素,要求 allowed_prob 越高越好
border_allowed_prob = allowed_prob[:, border_mask] # [B, N_border_pixels]
# 损失为 -log(allowed_prob)
loss = 1 - border_allowed_prob.mean()
return loss
def internal_wall_loss(pred, wall_class=1, threshold=2.5):
"""
针对内部区域排除最外圈设计的损失函数
当内部任意 2×2 区域的 wall 类别概率之和超过阈值时施加惩罚
参数:
pred: 模型输出形状 [B, C, H, W]
wall_class: 对应墙壁的类别索引这里假设墙壁数字为1
threshold: 2×2 区域概率之和的阈值超过此值时施加惩罚可根据实际情况调节
border_mask[0, :] = True # 第一行
border_mask[-1, :] = True # 最后一行
border_mask[:, 0] = True # 第一列
border_mask[:, -1] = True # 最后一列
返回:
loss: 内部墙壁连续区域的平均惩罚损失
"""
# 取出对应墙壁类别的概率图 [B, H, W]
wall_probs = pred[:, wall_class, :, :]
# 提取所有允许类别的概率和 [B, H, W]
allowed_probs = pred[:, allowed_classes, :, :].sum(dim=1)
# 排除最外圈,取内部区域 (H, W 均减去2)
interior = wall_probs[:, 1:-1, 1:-1] # [B, H-2, W-2]
# 获取外圈区域允许类别的概率 [B, N_pixels]
border_allowed = allowed_probs[:, border_mask]
# 构造一个 2×2 的卷积核,全为 1用于检测局部连续墙壁的概率之和
kernel = torch.ones((1, 1, 2, 2), device=pred.device)
# 计算不符合要求的概率(反向损失)
# 1 - 允许类别的概率 = 禁止类别的概率和
border_violation = 1 - border_allowed
# 对内部区域进行卷积操作,计算每个 2×2 区域内的概率和
# 需要将 interior 扩展一个通道维度
conv_result = F.conv2d(interior.unsqueeze(1), kernel, stride=1, padding=0)
# conv_result 的形状为 [B, 1, H-3, W-3]
# 使用平滑的Huber损失替代直接均值
loss = F.huber_loss(
border_violation,
torch.zeros_like(border_violation),
delta=0.1,
reduction='mean'
)
# 对于每个 2×2 区域,如果概率和超过 threshold则产生惩罚
# 这里采用 ReLU 计算超出部分,确保损失为非负
penalty = F.relu(conv_result - threshold)
# 取平均作为损失值
loss = penalty.mean()
return loss
return penalty_scale * loss
def entrance_loss(pred, stairs_class=10, arrow_class=11):
"""
针对地图生成的额外约束损失
- 保证最外圈不出现楼梯类型入口数字10
- 保证内部区域不出现箭头类型入口数字11
参数:
pred: 模型输出形状 [B, C, H, W]
stairs_class: 楼梯入口对应的类别数字10
arrow_class: 箭头入口对应的类别数字11
返回:
loss: 针对入口出现的惩罚损失
"""
# 先将 logits 转为概率分布
B, C, H, W = pred.shape
def _create_distance_kernel(size):
"""生成带距离权重的卷积核"""
kernel = torch.zeros(2*size-1, 2*size-1)
center = size-1
for i in range(2*size-1):
for j in range(2*size-1):
dist = math.sqrt((i-center)**2 + (j-center)**2)
kernel[i,j] = 1 / (1 + dist) # 距离越近权重越高
return kernel.view(1,1,2*size-1,2*size-1)
# 构造最外圈 mask外圈为 True其余为 False
outer_mask = torch.zeros((H, W), dtype=torch.bool, device=pred.device)
outer_mask[0, :] = True
outer_mask[-1, :] = True
outer_mask[:, 0] = True
outer_mask[:, -1] = True
# 内部区域 mask
interior_mask = ~outer_mask # 取反
# 提取对应类别的概率图
stairs_probs = pred[:, stairs_class, :, :] # 楼梯概率 [B, H, W]
arrow_probs = pred[:, arrow_class, :, :] # 箭头概率 [B, H, W]
# 从最外圈提取楼梯概率;用 mask 索引时:张量[:, mask] 会将每个样本的外圈像素展平
outer_stairs = stairs_probs[:, outer_mask] # [B, num_outer_pixels]
# 从内部区域提取箭头概率
interior_arrow = arrow_probs[:, interior_mask] # [B, num_interior_pixels]
# 损失设计:使得这些概率尽量接近 0直接使用均值惩罚
outer_loss = outer_stairs.mean()
interior_loss = interior_arrow.mean()
total_loss = outer_loss + interior_loss
return total_loss
def entrance_distance_and_presence_loss(
pred,
arrow_class=11, stairs_class=10,
arrow_min_threshold=0.5, stairs_min_threshold=0.5,
lambda_arrow_presence=1.0, lambda_stairs_presence=1.0
def entrance_constraint_loss(
pred: torch.Tensor,
entrance_classes=[10, 11], # 假设10是楼梯11是箭头
min_distance=9,
presence_threshold=0.9,
lambda_presence=1.0,
lambda_spacing=0.5
):
"""
入口损失同时考虑
1. 局部距离约束防止同一类型入口过于靠近
2. 存在性约束鼓励至少放置一个入口
箭头入口要求局部 (9x9) 内最多只有一个入口
楼梯入口要求在一个窗口地图尺寸一半内只出现一个楼梯入口
入口约束损失函数
参数:
pred: 模型输出, shape [B, C, H, W]
arrow_class: 箭头入口类别默认 11
stairs_class: 楼梯入口类别默认 10
arrow_min_threshold: 箭头入口全局最小平均概率要求可根据任务调节
stairs_min_threshold: 楼梯入口全局最小平均概率要求
lambda_arrow_presence: 箭头入口存在性损失权重
lambda_stairs_presence: 楼梯入口存在性损失权重
pred: 模型输出的概率分布 [B, C, H, W]
entrance_classes: 入口类别列表
min_distance: 最小间隔距离对应卷积核尺寸
presence_threshold: 存在性概率阈值
lambda_presence: 存在性损失权重
lambda_spacing: 间距约束权重
返回:
total_loss: 综合入口距离与存在性损失
total_loss: 综合损失值
"""
# 将 logits 转换为概率分布
B, C, H, W = pred.shape
entrance_probs = pred[:, entrance_classes].sum(dim=1)
# 提取箭头和楼梯的概率图
arrow_probs = pred[:, arrow_class, :, :] # [B, H, W]
stairs_probs = pred[:, stairs_class, :, :] # [B, H, W]
###########################
# 改进的存在性约束
###########################
# 计算存在性损失:鼓励至少有一个高置信度入口
max_per_sample = entrance_probs.view(B, -1).max(dim=1)[0]
presence_loss = F.relu(presence_threshold - max_per_sample).mean()
#### 局部距离约束 ####
# 箭头:构造 9x9 卷积核,半径 4
kernel_arrow = torch.ones((1, 1, 9, 9), device=pred.device)
local_arrow_sum = F.conv2d(arrow_probs.unsqueeze(1), kernel_arrow, padding=4)
# 减去自身概率,计算多余的局部累积
arrow_excess = local_arrow_sum - arrow_probs.unsqueeze(1)
arrow_distance_loss = F.relu(arrow_excess).mean()
###########################
# 改进的间距约束
###########################
# 生成空间权重掩码(中心衰减)
y, x = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
center_weight = 1 - torch.sqrt(((x-W//2)/W*2)**2 + ((y-H//2)/H*2)**2)
center_weight = center_weight.clamp(0,1).to(pred.device) # [H,W]
# 楼梯:使用窗口大小为 (W//2, H//2)
kernel_size_stairs = (9, 9)
kernel_stairs = torch.ones((1, 1, kernel_size_stairs[0], kernel_size_stairs[1]), device=pred.device)
pad_stairs = ((kernel_size_stairs[0] - 1) // 2, (kernel_size_stairs[1] - 1) // 2)
local_stairs_sum = F.conv2d(stairs_probs.unsqueeze(1), kernel_stairs, padding=pad_stairs)
stairs_excess = local_stairs_sum - stairs_probs.unsqueeze(1)
stairs_distance_loss = F.relu(stairs_excess).mean()
#### 存在性约束 ####
# 计算每个样本中箭头的最大概率
global_arrow_max = arrow_probs.view(B, -1).max(dim=1)[0] # [B]
global_stairs_max = stairs_probs.view(B, -1).max(dim=1)[0] # [B]
# 概率密度感知的间距计算
kernel = _create_distance_kernel(min_distance).to(pred.device) # 自定义函数生成权重核
density_map = F.conv2d(entrance_probs.unsqueeze(1), kernel, padding=min_distance-1)
# 取 batch 平均(或者你可以对每个样本分别计算损失再求平均)
global_arrow_max = global_arrow_max.mean()
global_stairs_max = global_stairs_max.mean()
# 如果全局均值低于预期阈值,则施加额外惩罚
arrow_presence_loss = F.relu(arrow_min_threshold - global_arrow_max)
stairs_presence_loss = F.relu(stairs_min_threshold - global_stairs_max)
# 平滑惩罚函数S形曲线
spacing_loss = torch.sigmoid(10*(density_map - 0.5)).mean() # 密度>0.5时快速上升
ap_weighted = lambda_arrow_presence * arrow_presence_loss
sp_weighted = lambda_stairs_presence * stairs_presence_loss
# print(entrance_probs)
print(presence_loss.item(), (density_map).mean().item(), center_weight.mean().item())
# 总入口损失:局部距离约束 + 存在性约束(加权)
total_loss = arrow_distance_loss + stairs_distance_loss \
+ min(ap_weighted, sp_weighted)
###########################
# 区域加权综合损失
###########################
total_loss = (
lambda_presence * presence_loss +
lambda_spacing * (spacing_loss * center_weight).mean()
)
return total_loss
def monster_consecutive_loss(pred, monster_classes=[7,8,9], threshold=2.9):
def adaptive_count_loss(
pred_probs: torch.Tensor,
target_map: torch.Tensor,
class_list: list = list(range(32)),
margin_ratio: float = 0.2,
zero_margin_scale: float = 0.3,
eps: float = 1e-6
) -> torch.Tensor:
"""
检查横向和纵向是否存在连续超过三个的怪物类别 7,8,9
自适应图块数量约束损失函数
参数:
pred: 模型输出形状 [B, C, H, W]
monster_classes: 待检测的怪物类别列表
threshold: 滑动窗口内概率和的阈值若超过则施加惩罚
对于连续三个像素如果每个像素概率接近 1则窗口和接近 3
pred_probs: 预测概率分布 [B, C, H, W]
target_map: 真实地图 [B, C, H, W]
class_list: 需要约束的类别列表
margin_ratio: 允许的相对误差范围如0.2表示±20%
zero_margin_scale: 参考数量为0时的允许余量系数余量=scale*sqrt(H*W))
eps: 数值稳定性常数
返回:
loss: 惩罚损失数值越高表示连续怪物区域越严重
"""
# 将 logits 转换为概率分布
B, C, H, W = pred.shape
# 得到怪物整体概率图:将类别 7,8,9 的概率相加
monster_probs = pred[:, monster_classes, :].sum(dim=1) # [B, H, W]
# 注意monster_probs 越高说明该像素更有可能是怪物
# --- 横向检测 ---
# 构造一个 (1,3) 的卷积核,全 1
kernel_horiz = torch.ones((1, 1, 1, 3), device=pred.device)
# 对 monster_probs 加一个 channel 维度,使形状为 [B, 1, H, W]
conv_horiz = F.conv2d(monster_probs.unsqueeze(1), kernel_horiz, padding=(0,1))
# conv_horiz 的每个值表示相邻三个像素的怪物概率和
# --- 纵向检测 ---
# 构造一个 (3,1) 的卷积核,全 1
kernel_vert = torch.ones((1, 1, 3, 1), device=pred.device)
conv_vert = F.conv2d(monster_probs.unsqueeze(1), kernel_vert, padding=(1,0))
# conv_vert 的每个值表示垂直连续三个像素的怪物概率和
# 对两个方向的窗口,如果概率和超过阈值,则计算超出部分的惩罚
penalty_horiz = F.relu(conv_horiz - threshold)
penalty_vert = F.relu(conv_vert - threshold)
# 将两个方向的惩罚损失取平均(或者直接相加)
loss = penalty_horiz.mean() + penalty_vert.mean()
return loss
def illegal_block_loss(pred, used_classes=12, mode='mean'):
"""
对未使用类别例如 12 ~ 31的预测概率施加惩罚
鼓励模型输出仅集中在 0 ~ 11
参数:
pred: 模型输出形状 [B, num_classes, H, W]
used_classes: 已经使用的类别数例如 12 表示只使用 0-11
mode: 'mean' 使用平均概率 'mse' 使用均方误差
返回:
penalty: 标量惩罚损失
"""
B, C, H, W = pred.shape
# 选取非法类别的概率(注意:这一步会得到非法图块在每个像素上的概率)
illegal_probs = pred[:, range(used_classes, 32), :, :] # [B, len(illegal_classes), H, W]
# 我们可以将非法图块的概率在类别维度上求和,得到每个像素的非法激活值
illegal_activation = illegal_probs.sum(dim=1) # [B, H, W]
# 接下来我们计算整个图上非法激活的“数量”
# 例如,可以直接对整个 batch 内非法激活求和
total_illegal = illegal_activation.sum() / B # 标量
# 计算损失值:使用负指数函数。注意如果非法激活很小,总损失接近 exp(0)=1
loss = torch.sqrt(total_illegal).mean()
return loss
def integrated_count_loss(probs, target, class_list=[0,1,2,3,4,5,6,7,8,9], tolerance=0.5):
"""
对每个类别分别计算数量匹配损失再取平均
参数:
probs: 模型输出的概率形状 [B, num_classes, H, W]
target: 真实标签形状 [B, H, W]类别取值在 0 ~ 使用范围-1
class_list: 需要计算的类别列表
tolerance: 每个类别允许的相对误差例如 0.15 表示 15%
返回:
loss: 对每个类别数量匹配损失取平均后的标量
loss: 标量损失值
"""
B, C, H, W = pred_probs.shape
device = pred_probs.device
total_loss = 0.0
count = 0
B, C, H, W = probs.shape
valid_classes = 0
# 预计算地图面积用于余量计算
map_area = math.sqrt(H * W)
for cls in class_list:
# 预测数量:对于当前类别,所有像素的预测概率和
pred_count = probs[:, cls, :, :].sum()
# 真实数量:统计 target 中属于当前类别的像素数量
true_count = (target == cls).float().sum()
# 预测数量(概率和)
pred_count = pred_probs[:, cls].sum(dim=(1,2)) # [B]
# 真实数量
true_count = target_map[:, cls].sum(dim=(1,2)) # [B]
if true_count == 0:
# 参考地图中不包含该类别,允许最多出现 (sqrt(地图尺寸) / 2) 个单位的概率输出
cls_loss = F.relu(pred_count - math.sqrt(H * W) / 2)
else:
# 计算相对误差
rel_error = torch.abs(pred_count - true_count) / (true_count)
cls_loss = F.relu(rel_error - tolerance)
# 动态容差计算
with torch.no_grad():
# 当真实数量为0时的允许上限
zero_mask = (true_count == 0)
dynamic_margin = torch.where(
zero_mask,
zero_margin_scale * map_area, # 允许存在少量
margin_ratio * true_count # 相对误差范围
)
total_loss += cls_loss
count += 1
# 误差计算(考虑数值稳定性)
safe_true = true_count + eps * zero_mask # 零真实值时添加微小量
abs_error = torch.abs(pred_count - true_count)
rel_error = abs_error / safe_true
# 求平均每个类别的损失
avg_loss = total_loss / count
return avg_loss
# 双阶段损失函数
# 阶段一:误差在容差范围内时使用二次函数(强梯度)
# 阶段二:超出容差时转为线性(稳定训练)
loss_per_class = torch.where(
abs_error <= dynamic_margin,
(rel_error ** 2) * 0.5, # 区间内强梯度
rel_error - (0.5 * margin_ratio) # 区间外稳定梯度
)
# 零真实值特殊处理:仅惩罚超出余量部分
loss_per_class = torch.where(
zero_mask,
F.relu(abs_error - dynamic_margin) / map_area, # 归一化处理
loss_per_class
)
total_loss += loss_per_class.mean()
valid_classes += 1
return total_loss / valid_classes # 类别平均
def illegal_tile_loss(
pred_probs: torch.Tensor,
legal_classes: int = 13,
temperature: float = 0.1,
eps: float = 1e-8
) -> torch.Tensor:
"""
非法图块惩罚损失函数
参数:
pred_probs: 模型输出的概率分布 [B, C, H, W]
legal_classes: 合法图块数量0-based默认0-12为合法
temperature: 概率锐化温度系数0.1-1.0
eps: 数值稳定性保护
返回:
loss: 标量损失值
"""
B, C, H, W = pred_probs.shape
# 提取非法图块概率类别13及之后
illegal_probs = pred_probs[:, legal_classes:, :, :] # [B, C_illegal, H, W]
# 概率锐化(增强高概率区域的惩罚)
sharpened_probs = torch.exp(torch.log(illegal_probs + eps) / temperature)
sharpened_probs = sharpened_probs / (sharpened_probs.sum(dim=1, keepdim=True) + eps)
# 空间敏感权重(关注高置信度非法区域)
with torch.no_grad():
# 计算每个像素的非法概率置信度
confidence = illegal_probs.max(dim=1)[0] # [B, H, W]
# 生成注意力权重(高置信度区域权重加倍)
spatial_weights = 1 + torch.sigmoid(10*(confidence - 0.5))
# 逐像素计算非法概率损失
per_pixel_loss = torch.log(1 + illegal_probs.sum(dim=1)) # [B, H, W]
# 加权空间损失
weighted_loss = (per_pixel_loss * spatial_weights).mean()
# 类别平衡因子(抑制高频非法类别)
class_balance = 1 + torch.var(illegal_probs.mean(dim=(0,2,3))) # [C_illegal]
return weighted_loss * class_balance.mean()
def entrance_spatial_constraint(
pred_probs: torch.Tensor,
arrow_class: int = 11,
stair_class: int = 10,
border_width: int = 1,
lambda_arrow: float = 1.0,
lambda_stair: float = 1.0
) -> torch.Tensor:
"""
入口空间约束损失函数
参数:
pred_probs: 模型输出的概率分布 [B, C, H, W]
arrow_class: 箭头入口类别索引
stair_class: 楼梯入口类别索引
border_width: 边缘区域宽度默认1表示最外圈
lambda_arrow: 箭头约束权重
lambda_stair: 楼梯约束权重
返回:
loss: 标量损失值
"""
B, C, H, W = pred_probs.shape
##########################################
# 1. 区域掩码生成
##########################################
# 生成边缘区域掩码 [H, W]
edge_mask = torch.zeros((H, W), dtype=torch.bool, device=pred_probs.device)
# 上下边缘
edge_mask[:border_width, :] = True
edge_mask[-border_width:, :] = True
# 左右边缘(排除已标记的角落)
edge_mask[:, :border_width] = True
edge_mask[:, -border_width:] = True
# 生成中间区域掩码 [H, W]
center_mask = ~edge_mask
##########################################
# 2. 边缘区域约束(只能出现箭头)
##########################################
# 提取边缘区域的箭头概率 [B, N_edge_pixels]
edge_arrow_probs = pred_probs[:, arrow_class][:, edge_mask]
# 边缘应最大化箭头概率最小化1 - arrow_prob
edge_arrow_loss = (1 - edge_arrow_probs).mean()
# 抑制边缘出现楼梯的概率 [B, N_edge_pixels]
edge_stair_probs = pred_probs[:, stair_class][:, edge_mask]
edge_stair_penalty = F.relu(edge_stair_probs - 0.1).mean() # 允许10%以下
##########################################
# 3. 中间区域约束(只能出现楼梯)
##########################################
# 提取中间区域的楼梯概率 [B, N_center_pixels]
center_stair_probs = pred_probs[:, stair_class][:, center_mask]
# 中间应最大化楼梯概率最小化1 - stair_prob
center_stair_loss = (1 - center_stair_probs).mean()
# 抑制中间出现箭头的概率 [B, N_center_pixels]
center_arrow_probs = pred_probs[:, arrow_class][:, center_mask]
center_arrow_penalty = F.relu(center_arrow_probs - 0.1).mean() # 允许10%以下
##########################################
# 4. 综合损失
##########################################
total_loss = (
lambda_arrow * (edge_arrow_loss + edge_stair_penalty) +
lambda_stair * (center_stair_loss + center_arrow_penalty)
)
return total_loss
class GinkaLoss(nn.Module):
def __init__(self, minamo: MinamoModel, weight=[0.35, 0.1, 0.1, 0.1, 0.1, 0.05, 0.1, 0.1]):
def __init__(self, minamo: MinamoModel, weight=[0.5, 0.15, 0.15, 0.1, 0.1]):
"""Ginka Model 损失函数部分
Args:
weight (list, optional): 每一个损失函数的权重从第 0 项开始依次是
1. 拓扑图损失
1. Minamo 相似度损失
2. 外圈墙壁损失
3. 内层 2*2 墙壁损失
4. 要求外层只能有箭头内层只能有楼梯的损失
5. 入口间距及存在性损失
6. 连续怪物损失
7. 非法图块损失
8. 怪物道具门数量损失
3. 入口间距及存在性损失
4. 怪物道具门数量损失
5. 非法图块损失
"""
super().__init__()
self.weight = weight
@ -304,41 +325,37 @@ class GinkaLoss(nn.Module):
def forward(self, pred, target, target_vision_feat, target_topo_feat):
# 地图结构损失
border_loss = wall_border_loss(pred)
wall_loss = internal_wall_loss(pred)
entry_loss = entrance_loss(pred)
entry_dis_loss = entrance_distance_and_presence_loss(pred, )
enemy_loss = monster_consecutive_loss(pred)
valid_block_loss = illegal_block_loss(pred, used_classes=12, mode="mean")
count_loss = integrated_count_loss(pred, target)
border_loss = outer_border_constraint_loss(pred)
entrance_loss = entrance_constraint_loss(pred) * 0.5 + entrance_spatial_constraint(pred) * 0.5
count_loss = adaptive_count_loss(pred, target)
illegal_loss = illegal_tile_loss(pred)
# 使用 Minamo Model 计算相似度
graph = convert_soft_map_to_graph(pred)
graph = batch_convert_soft_map_to_graph(pred)
pred_vision_feat, pred_topo_feat = self.minamo(pred, graph)
vision_sim = F.cosine_similarity(pred_vision_feat, target_vision_feat, dim=-1)
topo_sim = F.cosine_similarity(pred_topo_feat, target_topo_feat, dim=-1)
minamo_sim = 0.3 * vision_sim + 0.7 * topo_sim
minamo_loss = torch.exp(-1 * (minamo_sim - 0.8)).mean()
minamo_loss = (1.0 - minamo_sim).mean()
print(
minamo_loss.item(),
border_loss.item(),
wall_loss.item(),
entry_loss.item(),
entry_dis_loss.item(),
enemy_loss.item(),
valid_block_loss.item(),
count_loss.item()
entrance_loss.item(),
count_loss.item(),
illegal_loss.item()
)
return (
minamo_loss * self.weight[0] +
border_loss * self.weight[1] +
wall_loss * self.weight[2] +
entry_loss * self.weight[3] +
entry_dis_loss * self.weight[4] +
enemy_loss * self.weight[5] +
valid_block_loss * self.weight[6] +
count_loss * self.weight[7]
)
losses = [
minamo_loss * self.weight[0],
border_loss * self.weight[1] * 0.1,
entrance_loss * self.weight[2],
count_loss * self.weight[3],
illegal_loss * self.weight[4]
]
# 梯度归一化
scaled_losses = [loss / (loss.detach() + 1e-6) for loss in losses]
total_loss = sum(scaled_losses)
return total_loss

View File

@ -4,7 +4,7 @@ import torch.nn.functional as F
from .unet import GinkaUNet
class GinkaModel(nn.Module):
def __init__(self, feat_dim=256, base_ch=64, num_classes=32):
def __init__(self, feat_dim=256, base_ch=128, num_classes=32):
"""Ginka Model 模型定义部分
"""
super().__init__()
@ -25,5 +25,5 @@ class GinkaModel(nn.Module):
x = x.view(-1, self.base_ch, 32, 32)
x = self.unet(x)
x = F.interpolate(x, (13, 13), mode='bilinear', align_corners=False)
return F.softmax(x)
return F.softmax(x, dim=1)

View File

@ -46,13 +46,13 @@ def train():
)
# 设定优化器与调度器
optimizer = optim.AdamW(model.parameters(), lr=3e-4)
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
criterion = GinkaLoss(minamo, weight=[1, 0, 0, 0, 0, 0, 0, 0])
criterion = GinkaLoss(minamo)
model.register_full_backward_hook(grad_hook)
# model.register_full_backward_hook(grad_hook)
# converter.register_full_backward_hook(grad_hook)
criterion.register_full_backward_hook(grad_hook)
# criterion.register_full_backward_hook(grad_hook)
# 开始训练
for epoch in tqdm(range(epochs)):

View File

@ -1,5 +1,5 @@
import torch
from torch_geometric.data import Data
from torch_geometric.data import Data, Batch
def convert_soft_map_to_graph(map_probs: torch.Tensor):
"""
@ -31,3 +31,17 @@ def convert_soft_map_to_graph(map_probs: torch.Tensor):
map_probs[:, edge_index[1] // W, edge_index[1] % W]) / 2
return Data(x=node_features, edge_index=edge_index, edge_attr=soft_edge_weight)
def batch_convert_soft_map_to_graph(batch_map_probs):
"""
处理 batch 维度 [B, C, H, W] 转换为批量图结构 Batch
"""
B, C, H, W = batch_map_probs.shape # 获取 batch 维度
batch_graphs = []
for i in range(B):
graph = convert_soft_map_to_graph(batch_map_probs[i]) # 处理单个样本
batch_graphs.append(graph)
# 合并所有图为批量 Batch
return Batch.from_data_list(batch_graphs)