mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
feat: 修改 ginka 模型的损失值计算
This commit is contained in:
parent
4f7dbb6fb3
commit
96f828e29b
@ -5,7 +5,22 @@ import torch.nn.functional as F
|
||||
from minamo.model.model import MinamoModel
|
||||
from shared.graph import batch_convert_soft_map_to_graph
|
||||
|
||||
def outer_border_constraint_loss(pred: torch.Tensor, allowed_classes=[1, 11], penalty_scale=1.0):
|
||||
# Number of tile classes; the ids returned by get_not_allowed() lie in [0, CLASS_NUM).
CLASS_NUM = 32
# Largest "legal" class id; ids above it are only returned when include_illegal=True.
ILLEGAL_MAX_NUM = 12


def get_not_allowed(classes: list[int], include_illegal=False):
    """Return the class ids that are NOT contained in ``classes``.

    Args:
        classes: Class ids that are allowed and must be left out of the result.
        include_illegal: When False (default) only ids in
            ``[0, ILLEGAL_MAX_NUM]`` are considered. When True every id in
            ``[0, CLASS_NUM)`` is considered, i.e. ids greater than
            ``ILLEGAL_MAX_NUM`` are reported as well.

    Returns:
        list[int]: The disallowed class ids in ascending order.
    """
    # Build the lookup once so each membership test is O(1).
    allowed = set(classes)
    # Ids above ILLEGAL_MAX_NUM are only candidates when explicitly requested.
    upper = CLASS_NUM if include_illegal else min(CLASS_NUM, ILLEGAL_MAX_NUM + 1)
    return [num for num in range(upper) if num not in allowed]
|
||||
|
||||
def outer_border_constraint_loss(pred: torch.Tensor, allowed_classes=[1, 11]):
|
||||
"""
|
||||
强制地图最外圈像素必须为指定类别(墙或箭头)
|
||||
|
||||
@ -26,40 +41,58 @@ def outer_border_constraint_loss(pred: torch.Tensor, allowed_classes=[1, 11], pe
|
||||
border_mask[:, 0] = True # 第一列
|
||||
border_mask[:, -1] = True # 最后一列
|
||||
|
||||
# 提取所有允许类别的概率和 [B, H, W]
|
||||
allowed_probs = pred[:, allowed_classes, :, :].sum(dim=1)
|
||||
# 提取所有允许和不允许类别的概率和 [B, H, W]
|
||||
unallowed_probs = pred[:, get_not_allowed(allowed_classes, include_illegal=True), :, :].sum(dim=1)
|
||||
|
||||
# 获取外圈区域允许类别的概率 [B, N_pixels]
|
||||
border_allowed = allowed_probs[:, border_mask]
|
||||
border_unallowed = unallowed_probs[:, border_mask]
|
||||
|
||||
# 计算不符合要求的概率(反向损失)
|
||||
# 1 - 允许类别的概率 = 禁止类别的概率和
|
||||
border_violation = 1 - border_allowed
|
||||
target = torch.zeros_like(border_unallowed)
|
||||
loss_unallowed = F.mse_loss(border_unallowed, target)
|
||||
|
||||
# 使用平滑的Huber损失替代直接均值
|
||||
loss = F.huber_loss(
|
||||
border_violation,
|
||||
torch.zeros_like(border_violation),
|
||||
delta=0.1,
|
||||
reduction='mean'
|
||||
)
|
||||
return loss_unallowed
|
||||
|
||||
return penalty_scale * loss
|
||||
def inner_constraint_loss(pred: torch.Tensor, allowed=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12]):
    """Penalise tile classes that may not appear in the interior of the map.

    The interior is every cell except the outermost ring (first/last row and
    first/last column).

    Args:
        pred (torch.Tensor): Class probability distribution produced by the
            model, shape [B, C, H, W].
        allowed (list, optional): Tile classes permitted in the interior.
            Every other class id reported by
            ``get_not_allowed(allowed, include_illegal=True)`` is penalised.
            The default list is only read, never mutated.

    Returns:
        torch.Tensor: Scalar MSE between the summed probability of the
        disallowed classes on the interior cells and zero.
    """
    # Unpacking also enforces that ``pred`` is 4-dimensional.
    _, _, H, W = pred.shape

    # Interior mask [H, W]: True everywhere except on the outermost ring.
    mask = torch.ones((H, W), dtype=torch.bool, device=pred.device)
    mask[0, :] = False   # first row
    mask[-1, :] = False  # last row
    mask[:, 0] = False   # first column
    mask[:, -1] = False  # last column

    # Summed probability of all disallowed classes per cell [B, H, W].
    unallowed_probs = pred[:, get_not_allowed(allowed, include_illegal=True), :, :].sum(dim=1)

    # Disallowed-class probability on the interior cells only [B, N_pixels].
    inner_unallowed = unallowed_probs[:, mask]

    # Push the disallowed probability mass towards zero.
    target = torch.zeros_like(inner_unallowed)
    return F.mse_loss(inner_unallowed, target)
|
||||
|
||||
def _create_distance_kernel(size):
|
||||
"""生成一个环状衰减核"""
|
||||
y, x = torch.meshgrid(torch.arange(size), torch.arange(size), indexing='ij')
|
||||
center = size // 2
|
||||
dist = torch.sqrt((x - center)**2 + (y - center)**2)
|
||||
kernel = torch.exp(-dist / (size / 2)) # 高斯衰减
|
||||
kernel = 1 / (dist + 1)
|
||||
kernel /= kernel.sum() # 归一化
|
||||
return kernel.unsqueeze(0).unsqueeze(0) # [1,1,H,W]
|
||||
return kernel.unsqueeze(0).unsqueeze(0), 1 / kernel.sum() # [1,1,H,W]
|
||||
|
||||
def entrance_constraint_loss(
|
||||
pred: torch.Tensor,
|
||||
entrance_classes=[10, 11], # 假设10是楼梯,11是箭头
|
||||
min_distance=9,
|
||||
presence_threshold=0.9,
|
||||
presence_threshold=0.8,
|
||||
lambda_presence=1.0,
|
||||
lambda_spacing=0.5
|
||||
):
|
||||
@ -78,29 +111,23 @@ def entrance_constraint_loss(
|
||||
total_loss: 综合损失值
|
||||
"""
|
||||
B, C, H, W = pred.shape
|
||||
entrance_probs = pred[:, entrance_classes].sum(dim=1)
|
||||
entrance_probs = pred[:, entrance_classes, :, :].sum(dim=1) # [B, H, W]
|
||||
|
||||
###########################
|
||||
# 改进的存在性约束
|
||||
###########################
|
||||
# 计算存在性损失:鼓励至少有一个高置信度入口
|
||||
max_per_sample = entrance_probs.view(B, -1).max(dim=1)[0]
|
||||
max_per_sample = entrance_probs.view(B, -1).max(dim=1)[0] # [B, H*W] -> [B, 1]
|
||||
presence_loss = F.relu(presence_threshold - max_per_sample).mean()
|
||||
|
||||
###########################
|
||||
# 改进的间距约束
|
||||
###########################
|
||||
# 生成空间权重掩码(中心衰减)
|
||||
y, x = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
|
||||
center_weight = 1 - torch.sqrt(((x-W//2)/W*2)**2 + ((y-H//2)/H*2)**2)
|
||||
center_weight = center_weight.clamp(0,1).to(pred.device) # [H,W]
|
||||
|
||||
# 概率密度感知的间距计算
|
||||
kernel = _create_distance_kernel(min_distance).to(pred.device) # 自定义函数生成权重核
|
||||
kernel, cw = _create_distance_kernel(min_distance) # 自定义函数生成权重核
|
||||
kernel = kernel.to(pred.device)
|
||||
density_map = F.conv2d(entrance_probs.unsqueeze(1), kernel, padding=min_distance-1)
|
||||
|
||||
# 平滑惩罚函数:S形曲线
|
||||
spacing_loss = torch.sigmoid(10*(density_map - 0.5)).mean() # 密度>0.5时快速上升
|
||||
spacing_loss = density_map.mean()
|
||||
|
||||
###########################
|
||||
# 区域加权综合损失
|
||||
@ -115,16 +142,16 @@ def adaptive_count_loss(
|
||||
pred_probs: torch.Tensor,
|
||||
target_map: torch.Tensor,
|
||||
class_list: list = list(range(32)),
|
||||
margin_ratio: float = 0.2,
|
||||
zero_margin_scale: float = 0.2,
|
||||
lambda_entropy: float = 0.05,
|
||||
lambda_local: float = 0.1,
|
||||
grid_size: int = 8,
|
||||
margin_ratio: float = 0.1, # 降低margin比例以更严格
|
||||
zero_margin_scale: float = 0.1, # 减少零类别的margin
|
||||
lambda_entropy: float = 0.2, # 增大熵约束权重
|
||||
lambda_local: float = 0.2,
|
||||
lambda_max: float = 0, # 新增最大概率约束
|
||||
grid_size: int = 4, # 减小局部网格尺寸
|
||||
eps: float = 1e-3
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
改进版自适应图块数量约束损失,包含局部匹配和熵约束
|
||||
|
||||
改进版自适应图块数量约束损失,增强局部匹配和概率确定性
|
||||
"""
|
||||
B, C, H, W = pred_probs.shape
|
||||
device = pred_probs.device
|
||||
@ -134,13 +161,14 @@ def adaptive_count_loss(
|
||||
# 预计算地图面积
|
||||
map_area = math.sqrt(H * W)
|
||||
|
||||
# 计算最小非零类别概率
|
||||
min_nonzero_prob = pred_probs[:, class_list].max(dim=1).values.mean() # 获取预测中的最小非零概率
|
||||
dynamic_zero_margin = zero_margin_scale * min_nonzero_prob * map_area # 让零类别不被填充
|
||||
# 动态调整零类别的margin:基于预测中最小的非零概率
|
||||
min_nonzero_prob = pred_probs[:, class_list].max(dim=1).values.mean()
|
||||
dynamic_zero_margin = zero_margin_scale * min_nonzero_prob * map_area
|
||||
|
||||
# 计算每个类别的数量损失
|
||||
for cls in class_list:
|
||||
pred_count = pred_probs[:, cls].sum(dim=(1,2)) # 预测类别数量
|
||||
true_count = target_map[:, cls].sum(dim=(1,2)) # 真实类别数量
|
||||
pred_count = pred_probs[:, cls].sum(dim=(1,2)) # 预测数量
|
||||
true_count = target_map[:, cls].sum(dim=(1,2)) # 真实数量
|
||||
|
||||
zero_mask = (true_count == 0)
|
||||
dynamic_margin = torch.where(
|
||||
@ -153,37 +181,45 @@ def adaptive_count_loss(
|
||||
abs_error = torch.abs(pred_count - true_count)
|
||||
rel_error = abs_error / safe_true
|
||||
|
||||
# 调整损失函数形状,远离目标时惩罚更大
|
||||
loss_per_class = torch.where(
|
||||
abs_error <= dynamic_margin,
|
||||
(rel_error ** 2) * 0.8 + 0.2 * rel_error,
|
||||
rel_error - (0.5 * margin_ratio)
|
||||
rel_error ** 2, # 近目标时二次损失
|
||||
(rel_error - 0.5 * margin_ratio) ** 2 # 远目标时二次增长
|
||||
)
|
||||
|
||||
# 零类别使用更严格的绝对误差惩罚
|
||||
loss_per_class = torch.where(
|
||||
zero_mask,
|
||||
F.relu(abs_error - dynamic_margin) / map_area,
|
||||
F.relu(abs_error - dynamic_zero_margin) ** 2 / map_area,
|
||||
loss_per_class
|
||||
)
|
||||
|
||||
total_loss += loss_per_class.mean()
|
||||
valid_classes += 1
|
||||
|
||||
# 平均类别损失
|
||||
total_loss /= valid_classes
|
||||
total_loss /= valid_classes # 平均类别损失
|
||||
|
||||
# 加入负熵约束,防止类别均匀化
|
||||
# 改进的熵约束:每个像素的熵
|
||||
def entropy_loss(pred_probs):
|
||||
avg_probs = pred_probs.mean(dim=(2, 3))
|
||||
entropy = -torch.sum(avg_probs * torch.log(avg_probs + 1e-6), dim=1)
|
||||
return entropy.mean()
|
||||
entropy_per_pixel = -torch.sum(pred_probs * torch.log(pred_probs + 1e-6), dim=1)
|
||||
return entropy_per_pixel.mean() # 所有像素的平均熵
|
||||
|
||||
total_loss += lambda_entropy * entropy_loss(pred_probs)
|
||||
|
||||
# 加入局部类别匹配
|
||||
def local_count_loss(pred_probs, target_probs, grid_size=8):
|
||||
pred_local = F.avg_pool2d(pred_probs, kernel_size=grid_size, stride=grid_size)
|
||||
target_local = F.avg_pool2d(target_probs, kernel_size=grid_size, stride=grid_size)
|
||||
return F.mse_loss(pred_local, target_local)
|
||||
# 新增最大概率约束:鼓励每个位置概率尖锐化
|
||||
max_probs = pred_probs.max(dim=1)[0] # 每个位置的最大概率
|
||||
max_loss = (1 - max_probs).mean() # 鼓励接近1
|
||||
total_loss += lambda_max * max_loss
|
||||
|
||||
# 改进局部损失:约束局部区域内的数量
|
||||
def local_count_loss(pred_probs, target_probs, grid_size):
|
||||
grid_area = grid_size ** 2
|
||||
# 计算每个grid内的预测数量
|
||||
pred_counts = F.avg_pool2d(pred_probs, grid_size, stride=grid_size) * grid_area
|
||||
target_counts = F.avg_pool2d(target_probs, grid_size, stride=grid_size) * grid_area
|
||||
# 使用L1损失更鲁棒
|
||||
return F.l1_loss(pred_counts, target_counts)
|
||||
|
||||
total_loss += lambda_local * local_count_loss(pred_probs, target_map, grid_size)
|
||||
|
||||
@ -300,16 +336,15 @@ def entrance_spatial_constraint(
|
||||
return total_loss
|
||||
|
||||
class GinkaLoss(nn.Module):
|
||||
def __init__(self, minamo: MinamoModel, weight=[0.5, 0.1, 0.1, 0.2, 0.1]):
|
||||
def __init__(self, minamo: MinamoModel, weight=[0.5, 0.2, 0.1, 0.2]):
|
||||
"""Ginka Model 损失函数部分
|
||||
|
||||
Args:
|
||||
weight (list, optional): 每一个损失函数的权重,从第 0 项开始,依次是:
|
||||
1. Minamo 相似度损失
|
||||
2. 外圈墙壁损失
|
||||
2. 图块种类损失,要求内部不出现箭头,外圈只出现箭头和墙壁,不允许出现非法图块
|
||||
3. 入口间距及存在性损失
|
||||
4. 怪物、道具、门数量损失
|
||||
5. 非法图块损失
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = weight
|
||||
@ -317,10 +352,9 @@ class GinkaLoss(nn.Module):
|
||||
|
||||
def forward(self, pred, target, target_vision_feat, target_topo_feat):
|
||||
# 地图结构损失
|
||||
border_loss = outer_border_constraint_loss(pred)
|
||||
entrance_loss = entrance_constraint_loss(pred) * 0.5 + entrance_spatial_constraint(pred) * 0.5
|
||||
class_loss = outer_border_constraint_loss(pred) + inner_constraint_loss(pred)
|
||||
entrance_loss = entrance_constraint_loss(pred) + entrance_spatial_constraint(pred)
|
||||
count_loss = adaptive_count_loss(pred, target)
|
||||
illegal_loss = illegal_tile_loss(pred)
|
||||
|
||||
# 使用 Minamo Model 计算相似度
|
||||
graph = batch_convert_soft_map_to_graph(pred)
|
||||
@ -328,23 +362,21 @@ class GinkaLoss(nn.Module):
|
||||
|
||||
vision_sim = F.cosine_similarity(pred_vision_feat, target_vision_feat, dim=-1)
|
||||
topo_sim = F.cosine_similarity(pred_topo_feat, target_topo_feat, dim=-1)
|
||||
minamo_sim = 0.3 * vision_sim + 0.7 * topo_sim
|
||||
minamo_sim = 0.2 * vision_sim + 0.8 * topo_sim
|
||||
minamo_loss = (1.0 - minamo_sim).mean()
|
||||
|
||||
print(
|
||||
minamo_loss.item(),
|
||||
border_loss.item(),
|
||||
class_loss.item(),
|
||||
entrance_loss.item(),
|
||||
count_loss.item(),
|
||||
illegal_loss.item()
|
||||
count_loss.item()
|
||||
)
|
||||
|
||||
losses = [
|
||||
minamo_loss * self.weight[0],
|
||||
border_loss * self.weight[1],
|
||||
class_loss * self.weight[1],
|
||||
entrance_loss * self.weight[2],
|
||||
count_loss * self.weight[3],
|
||||
illegal_loss * self.weight[4]
|
||||
count_loss * self.weight[3]
|
||||
]
|
||||
|
||||
# 梯度归一化
|
||||
|
||||
@ -8,12 +8,12 @@ def print_memory(tag=""):
|
||||
print(f"{tag} | 当前显存: {torch.cuda.memory_allocated() / 1024**2:.2f} MB, 最大显存: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
|
||||
|
||||
class GinkaModel(nn.Module):
|
||||
def __init__(self, feat_dim=1024, base_ch=64, num_classes=32):
|
||||
def __init__(self, feat_dim=1024, base_ch=64, out_ch=32):
|
||||
"""Ginka Model 模型定义部分
|
||||
"""
|
||||
super().__init__()
|
||||
self.unet = GinkaUNet(1, base_ch, num_classes, feat_dim)
|
||||
self.output = GinkaOutput(num_classes, (13, 13))
|
||||
self.unet = GinkaUNet(1, base_ch, out_ch, feat_dim)
|
||||
self.output = GinkaOutput(out_ch, out_ch, (13, 13))
|
||||
|
||||
def forward(self, x, feat):
|
||||
"""
|
||||
|
||||
@ -2,9 +2,12 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class GinkaOutput(nn.Module):
|
||||
def __init__(self, num_classes=32, out_size=(13, 13)):
|
||||
def __init__(self, out_ch=32, base_ch=64, out_size=(13, 13)):
|
||||
super().__init__()
|
||||
self.pool = nn.AdaptiveAvgPool2d(out_size)
|
||||
self.conv_down = nn.Sequential(
|
||||
nn.AdaptiveMaxPool2d(out_size),
|
||||
nn.Conv2d(base_ch, out_ch, 1)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.pool(x)
|
||||
return self.conv_down(x)
|
||||
|
||||
@ -49,7 +49,7 @@ def train():
|
||||
|
||||
if args.resume:
|
||||
data = torch.load(args.from_state, map_location=device)
|
||||
model.load_state_dict(data["model_state"])
|
||||
model.load_state_dict(data["model_state"], strict=False)
|
||||
if args.load_optim:
|
||||
optimizer.load_state_dict(data["optimizer_state"])
|
||||
print("Train from loaded state.")
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import torch.nn as nn
|
||||
|
||||
class MinamoLoss(nn.Module):
|
||||
def __init__(self, vision_weight=0.4, topo_weight=0.6):
|
||||
def __init__(self, vision_weight=0.2, topo_weight=0.8):
|
||||
super().__init__()
|
||||
self.vision_weight = vision_weight
|
||||
self.topo_weight = topo_weight
|
||||
|
||||
Loading…
Reference in New Issue
Block a user