feat: 修改 ginka 模型的损失值计算

This commit is contained in:
unanmed 2025-03-31 13:46:33 +08:00
parent 4f7dbb6fb3
commit 96f828e29b
5 changed files with 114 additions and 79 deletions

View File

@ -5,7 +5,22 @@ import torch.nn.functional as F
from minamo.model.model import MinamoModel
from shared.graph import batch_convert_soft_map_to_graph
def outer_border_constraint_loss(pred: torch.Tensor, allowed_classes=[1, 11], penalty_scale=1.0):
CLASS_NUM = 32
ILLEGAL_MAX_NUM = 12
def get_not_allowed(classes: list[int], include_illegal=False):
res = list()
for num in range(0, CLASS_NUM):
if not num in classes:
if num > ILLEGAL_MAX_NUM:
if include_illegal:
res.append(num)
else:
res.append(num)
return res
def outer_border_constraint_loss(pred: torch.Tensor, allowed_classes=[1, 11]):
"""
强制地图最外圈像素必须为指定类别墙或箭头
@ -26,40 +41,58 @@ def outer_border_constraint_loss(pred: torch.Tensor, allowed_classes=[1, 11], pe
border_mask[:, 0] = True # 第一列
border_mask[:, -1] = True # 最后一列
# 提取所有允许类别的概率和 [B, H, W]
allowed_probs = pred[:, allowed_classes, :, :].sum(dim=1)
# 提取所有允许和不允许类别的概率和 [B, H, W]
unallowed_probs = pred[:, get_not_allowed(allowed_classes, include_illegal=True), :, :].sum(dim=1)
# 获取外圈区域允许类别的概率 [B, N_pixels]
border_allowed = allowed_probs[:, border_mask]
border_unallowed = unallowed_probs[:, border_mask]
# 计算不符合要求的概率(反向损失)
# 1 - 允许类别的概率 = 禁止类别的概率和
border_violation = 1 - border_allowed
target = torch.zeros_like(border_unallowed)
loss_unallowed = F.mse_loss(border_unallowed, target)
# 使用平滑的Huber损失替代直接均值
loss = F.huber_loss(
border_violation,
torch.zeros_like(border_violation),
delta=0.1,
reduction='mean'
)
return loss_unallowed
return penalty_scale * loss
def inner_constraint_loss(pred: torch.Tensor, allowed=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12]):
"""限定内部允许出现的图块种类
Args:
pred (torch.Tensor): 模型输出的概率分布 [B, C, H, W]
unallowed (list, optional): 在地图中部处最外圈允许出现的图块种类. Defaults to [11].
"""
B, C, H, W = pred.shape
# 创建内部 mask [H, W]
mask = torch.ones((H, W), dtype=torch.bool, device=pred.device)
mask[0, :] = False # 第一行
mask[-1, :] = False # 最后一行
mask[:, 0] = False # 第一列
mask[:, -1] = False # 最后一列
# 提取所有允许和不允许类别的概率和 [B, H, W]
unallowed_probs = pred[:, get_not_allowed(allowed, include_illegal=True), :, :].sum(dim=1)
# 获取外圈区域允许类别的概率 [B, N_pixels]
inner_unallowed = unallowed_probs[:, mask]
target = torch.zeros_like(inner_unallowed)
loss_unallowed = F.mse_loss(inner_unallowed, target)
return loss_unallowed
def _create_distance_kernel(size):
"""生成一个环状衰减核"""
y, x = torch.meshgrid(torch.arange(size), torch.arange(size), indexing='ij')
center = size // 2
dist = torch.sqrt((x - center)**2 + (y - center)**2)
kernel = torch.exp(-dist / (size / 2)) # 高斯衰减
kernel = 1 / (dist + 1)
kernel /= kernel.sum() # 归一化
return kernel.unsqueeze(0).unsqueeze(0) # [1,1,H,W]
return kernel.unsqueeze(0).unsqueeze(0), 1 / kernel.sum() # [1,1,H,W]
def entrance_constraint_loss(
pred: torch.Tensor,
entrance_classes=[10, 11], # 假设10是楼梯11是箭头
min_distance=9,
presence_threshold=0.9,
presence_threshold=0.8,
lambda_presence=1.0,
lambda_spacing=0.5
):
@ -78,29 +111,23 @@ def entrance_constraint_loss(
total_loss: 综合损失值
"""
B, C, H, W = pred.shape
entrance_probs = pred[:, entrance_classes].sum(dim=1)
entrance_probs = pred[:, entrance_classes, :, :].sum(dim=1) # [B, H, W]
###########################
# 改进的存在性约束
###########################
# 计算存在性损失:鼓励至少有一个高置信度入口
max_per_sample = entrance_probs.view(B, -1).max(dim=1)[0]
max_per_sample = entrance_probs.view(B, -1).max(dim=1)[0] # [B, H*W] -> [B, 1]
presence_loss = F.relu(presence_threshold - max_per_sample).mean()
###########################
# 改进的间距约束
###########################
# 生成空间权重掩码(中心衰减)
y, x = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
center_weight = 1 - torch.sqrt(((x-W//2)/W*2)**2 + ((y-H//2)/H*2)**2)
center_weight = center_weight.clamp(0,1).to(pred.device) # [H,W]
# 概率密度感知的间距计算
kernel = _create_distance_kernel(min_distance).to(pred.device) # 自定义函数生成权重核
kernel, cw = _create_distance_kernel(min_distance) # 自定义函数生成权重核
kernel = kernel.to(pred.device)
density_map = F.conv2d(entrance_probs.unsqueeze(1), kernel, padding=min_distance-1)
# 平滑惩罚函数S形曲线
spacing_loss = torch.sigmoid(10*(density_map - 0.5)).mean() # 密度>0.5时快速上升
spacing_loss = density_map.mean()
###########################
# 区域加权综合损失
@ -115,16 +142,16 @@ def adaptive_count_loss(
pred_probs: torch.Tensor,
target_map: torch.Tensor,
class_list: list = list(range(32)),
margin_ratio: float = 0.2,
zero_margin_scale: float = 0.2,
lambda_entropy: float = 0.05,
lambda_local: float = 0.1,
grid_size: int = 8,
margin_ratio: float = 0.1, # 降低margin比例以更严格
zero_margin_scale: float = 0.1, # 减少零类别的margin
lambda_entropy: float = 0.2, # 增大熵约束权重
lambda_local: float = 0.2,
lambda_max: float = 0, # 新增最大概率约束
grid_size: int = 4, # 减小局部网格尺寸
eps: float = 1e-3
) -> torch.Tensor:
"""
改进版自适应图块数量约束损失包含局部匹配和熵约束
改进版自适应图块数量约束损失增强局部匹配和概率确定性
"""
B, C, H, W = pred_probs.shape
device = pred_probs.device
@ -134,13 +161,14 @@ def adaptive_count_loss(
# 预计算地图面积
map_area = math.sqrt(H * W)
# 计算最小非零类别概率
min_nonzero_prob = pred_probs[:, class_list].max(dim=1).values.mean() # 获取预测中的最小非零概率
dynamic_zero_margin = zero_margin_scale * min_nonzero_prob * map_area # 让零类别不被填充
# 动态调整零类别的margin基于预测中最小的非零概率
min_nonzero_prob = pred_probs[:, class_list].max(dim=1).values.mean()
dynamic_zero_margin = zero_margin_scale * min_nonzero_prob * map_area
# 计算每个类别的数量损失
for cls in class_list:
pred_count = pred_probs[:, cls].sum(dim=(1,2)) # 预测类别数量
true_count = target_map[:, cls].sum(dim=(1,2)) # 真实类别数量
pred_count = pred_probs[:, cls].sum(dim=(1,2)) # 预测数量
true_count = target_map[:, cls].sum(dim=(1,2)) # 真实数量
zero_mask = (true_count == 0)
dynamic_margin = torch.where(
@ -153,37 +181,45 @@ def adaptive_count_loss(
abs_error = torch.abs(pred_count - true_count)
rel_error = abs_error / safe_true
# 调整损失函数形状,远离目标时惩罚更大
loss_per_class = torch.where(
abs_error <= dynamic_margin,
(rel_error ** 2) * 0.8 + 0.2 * rel_error,
rel_error - (0.5 * margin_ratio)
rel_error ** 2, # 近目标时二次损失
(rel_error - 0.5 * margin_ratio) ** 2 # 远目标时二次增长
)
# 零类别使用更严格的绝对误差惩罚
loss_per_class = torch.where(
zero_mask,
F.relu(abs_error - dynamic_margin) / map_area,
F.relu(abs_error - dynamic_zero_margin) ** 2 / map_area,
loss_per_class
)
total_loss += loss_per_class.mean()
valid_classes += 1
# 平均类别损失
total_loss /= valid_classes
total_loss /= valid_classes # 平均类别损失
# 加入负熵约束,防止类别均匀化
# 改进的熵约束:每个像素的熵
def entropy_loss(pred_probs):
avg_probs = pred_probs.mean(dim=(2, 3))
entropy = -torch.sum(avg_probs * torch.log(avg_probs + 1e-6), dim=1)
return entropy.mean()
entropy_per_pixel = -torch.sum(pred_probs * torch.log(pred_probs + 1e-6), dim=1)
return entropy_per_pixel.mean() # 所有像素的平均熵
total_loss += lambda_entropy * entropy_loss(pred_probs)
# 加入局部类别匹配
def local_count_loss(pred_probs, target_probs, grid_size=8):
pred_local = F.avg_pool2d(pred_probs, kernel_size=grid_size, stride=grid_size)
target_local = F.avg_pool2d(target_probs, kernel_size=grid_size, stride=grid_size)
return F.mse_loss(pred_local, target_local)
# 新增最大概率约束:鼓励每个位置概率尖锐化
max_probs = pred_probs.max(dim=1)[0] # 每个位置的最大概率
max_loss = (1 - max_probs).mean() # 鼓励接近1
total_loss += lambda_max * max_loss
# 改进局部损失:约束局部区域内的数量
def local_count_loss(pred_probs, target_probs, grid_size):
grid_area = grid_size ** 2
# 计算每个grid内的预测数量
pred_counts = F.avg_pool2d(pred_probs, grid_size, stride=grid_size) * grid_area
target_counts = F.avg_pool2d(target_probs, grid_size, stride=grid_size) * grid_area
# 使用L1损失更鲁棒
return F.l1_loss(pred_counts, target_counts)
total_loss += lambda_local * local_count_loss(pred_probs, target_map, grid_size)
@ -300,16 +336,15 @@ def entrance_spatial_constraint(
return total_loss
class GinkaLoss(nn.Module):
def __init__(self, minamo: MinamoModel, weight=[0.5, 0.1, 0.1, 0.2, 0.1]):
def __init__(self, minamo: MinamoModel, weight=[0.5, 0.2, 0.1, 0.2]):
"""Ginka Model 损失函数部分
Args:
weight (list, optional): 每一个损失函数的权重从第 0 项开始依次是
1. Minamo 相似度损失
2. 外圈墙壁损失
2. 图块种类损失要求内部不出现箭头外圈只出现箭头和墙壁不允许出现非法图块
3. 入口间距及存在性损失
4. 怪物道具门数量损失
5. 非法图块损失
"""
super().__init__()
self.weight = weight
@ -317,10 +352,9 @@ class GinkaLoss(nn.Module):
def forward(self, pred, target, target_vision_feat, target_topo_feat):
# 地图结构损失
border_loss = outer_border_constraint_loss(pred)
entrance_loss = entrance_constraint_loss(pred) * 0.5 + entrance_spatial_constraint(pred) * 0.5
class_loss = outer_border_constraint_loss(pred) + inner_constraint_loss(pred)
entrance_loss = entrance_constraint_loss(pred) + entrance_spatial_constraint(pred)
count_loss = adaptive_count_loss(pred, target)
illegal_loss = illegal_tile_loss(pred)
# 使用 Minamo Model 计算相似度
graph = batch_convert_soft_map_to_graph(pred)
@ -328,23 +362,21 @@ class GinkaLoss(nn.Module):
vision_sim = F.cosine_similarity(pred_vision_feat, target_vision_feat, dim=-1)
topo_sim = F.cosine_similarity(pred_topo_feat, target_topo_feat, dim=-1)
minamo_sim = 0.3 * vision_sim + 0.7 * topo_sim
minamo_sim = 0.2 * vision_sim + 0.8 * topo_sim
minamo_loss = (1.0 - minamo_sim).mean()
print(
minamo_loss.item(),
border_loss.item(),
class_loss.item(),
entrance_loss.item(),
count_loss.item(),
illegal_loss.item()
count_loss.item()
)
losses = [
minamo_loss * self.weight[0],
border_loss * self.weight[1],
class_loss * self.weight[1],
entrance_loss * self.weight[2],
count_loss * self.weight[3],
illegal_loss * self.weight[4]
count_loss * self.weight[3]
]
# 梯度归一化

View File

@ -8,12 +8,12 @@ def print_memory(tag=""):
print(f"{tag} | 当前显存: {torch.cuda.memory_allocated() / 1024**2:.2f} MB, 最大显存: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
class GinkaModel(nn.Module):
def __init__(self, feat_dim=1024, base_ch=64, num_classes=32):
def __init__(self, feat_dim=1024, base_ch=64, out_ch=32):
"""Ginka Model 模型定义部分
"""
super().__init__()
self.unet = GinkaUNet(1, base_ch, num_classes, feat_dim)
self.output = GinkaOutput(num_classes, (13, 13))
self.unet = GinkaUNet(1, base_ch, out_ch, feat_dim)
self.output = GinkaOutput(out_ch, out_ch, (13, 13))
def forward(self, x, feat):
"""

View File

@ -2,9 +2,12 @@ import torch
import torch.nn as nn
class GinkaOutput(nn.Module):
def __init__(self, num_classes=32, out_size=(13, 13)):
def __init__(self, out_ch=32, base_ch=64, out_size=(13, 13)):
super().__init__()
self.pool = nn.AdaptiveAvgPool2d(out_size)
self.conv_down = nn.Sequential(
nn.AdaptiveMaxPool2d(out_size),
nn.Conv2d(base_ch, out_ch, 1)
)
def forward(self, x):
return self.pool(x)
return self.conv_down(x)

View File

@ -49,7 +49,7 @@ def train():
if args.resume:
data = torch.load(args.from_state, map_location=device)
model.load_state_dict(data["model_state"])
model.load_state_dict(data["model_state"], strict=False)
if args.load_optim:
optimizer.load_state_dict(data["optimizer_state"])
print("Train from loaded state.")

View File

@ -1,7 +1,7 @@
import torch.nn as nn
class MinamoLoss(nn.Module):
def __init__(self, vision_weight=0.4, topo_weight=0.6):
def __init__(self, vision_weight=0.2, topo_weight=0.8):
super().__init__()
self.vision_weight = vision_weight
self.topo_weight = topo_weight