mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-23 12:21:11 +08:00
refactor: GINKA Model 改为 softmax 输出
This commit is contained in:
parent
cb9e67dff7
commit
fd72b1e7f4
@ -1,8 +1,9 @@
|
|||||||
import json
|
import json
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from minamo.model.model import MinamoModel
|
from minamo.model.model import MinamoModel
|
||||||
from shared.graph import convert_map_to_graph
|
from shared.graph import convert_soft_map_to_graph
|
||||||
|
|
||||||
def load_data(path: str):
|
def load_data(path: str):
|
||||||
with open(path, 'r', encoding="utf-8") as f:
|
with open(path, 'r', encoding="utf-8") as f:
|
||||||
@ -27,8 +28,8 @@ class GinkaDataset(Dataset):
|
|||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
item = self.data[idx]
|
item = self.data[idx]
|
||||||
|
|
||||||
target = torch.tensor(item["map"]).to(self.device)
|
target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float().to(self.device) # [32, H, W]
|
||||||
graph = convert_map_to_graph(target).to(self.device)
|
graph = convert_soft_map_to_graph(target).to(self.device)
|
||||||
vision_feat, topo_feat = self.minamo(target.unsqueeze(0), graph)
|
vision_feat, topo_feat = self.minamo(target.unsqueeze(0), graph)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|||||||
@ -3,9 +3,9 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from minamo.model.model import MinamoModel
|
from minamo.model.model import MinamoModel
|
||||||
from shared.graph import DynamicGraphConverter
|
from shared.graph import convert_soft_map_to_graph
|
||||||
|
|
||||||
def wall_border_loss(pred: torch.Tensor, probs: torch.Tensor, allow_border=[1, 11]):
|
def wall_border_loss(pred: torch.Tensor, allow_border=[1, 11]):
|
||||||
"""地图最外层是否为墙"""
|
"""地图最外层是否为墙"""
|
||||||
# 计算 softmax 概率
|
# 计算 softmax 概率
|
||||||
B, C, H, W = pred.shape
|
B, C, H, W = pred.shape
|
||||||
@ -18,7 +18,7 @@ def wall_border_loss(pred: torch.Tensor, probs: torch.Tensor, allow_border=[1, 1
|
|||||||
border_mask[:, -1] = True
|
border_mask[:, -1] = True
|
||||||
|
|
||||||
# 对允许的类别求概率和(即该像素为允许类别的总概率)
|
# 对允许的类别求概率和(即该像素为允许类别的总概率)
|
||||||
allowed_prob = probs[:, allow_border, :, :].sum(dim=1) # [B, H, W]
|
allowed_prob = pred[:, allow_border, :, :].sum(dim=1) # [B, H, W]
|
||||||
|
|
||||||
# 只计算边界区域的损失:对于边界上的每个像素,要求 allowed_prob 越高越好
|
# 只计算边界区域的损失:对于边界上的每个像素,要求 allowed_prob 越高越好
|
||||||
border_allowed_prob = allowed_prob[:, border_mask] # [B, N_border_pixels]
|
border_allowed_prob = allowed_prob[:, border_mask] # [B, N_border_pixels]
|
||||||
@ -28,13 +28,13 @@ def wall_border_loss(pred: torch.Tensor, probs: torch.Tensor, allow_border=[1, 1
|
|||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
def internal_wall_loss(logits, probs, wall_class=1, threshold=2.5):
|
def internal_wall_loss(pred, wall_class=1, threshold=2.5):
|
||||||
"""
|
"""
|
||||||
针对内部区域(排除最外圈)设计的损失函数:
|
针对内部区域(排除最外圈)设计的损失函数:
|
||||||
当内部任意 2×2 区域的 wall 类别概率之和超过阈值时,施加惩罚。
|
当内部任意 2×2 区域的 wall 类别概率之和超过阈值时,施加惩罚。
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
logits: 模型输出,形状 [B, C, H, W]
|
pred: 模型输出,形状 [B, C, H, W]
|
||||||
wall_class: 对应墙壁的类别索引(这里假设墙壁数字为1)
|
wall_class: 对应墙壁的类别索引(这里假设墙壁数字为1)
|
||||||
threshold: 2×2 区域概率之和的阈值,超过此值时施加惩罚。可根据实际情况调节。
|
threshold: 2×2 区域概率之和的阈值,超过此值时施加惩罚。可根据实际情况调节。
|
||||||
|
|
||||||
@ -42,13 +42,13 @@ def internal_wall_loss(logits, probs, wall_class=1, threshold=2.5):
|
|||||||
loss: 内部墙壁连续区域的平均惩罚损失
|
loss: 内部墙壁连续区域的平均惩罚损失
|
||||||
"""
|
"""
|
||||||
# 取出对应墙壁类别的概率图 [B, H, W]
|
# 取出对应墙壁类别的概率图 [B, H, W]
|
||||||
wall_probs = probs[:, wall_class, :, :]
|
wall_probs = pred[:, wall_class, :, :]
|
||||||
|
|
||||||
# 排除最外圈,取内部区域 (H, W 均减去2)
|
# 排除最外圈,取内部区域 (H, W 均减去2)
|
||||||
interior = wall_probs[:, 1:-1, 1:-1] # [B, H-2, W-2]
|
interior = wall_probs[:, 1:-1, 1:-1] # [B, H-2, W-2]
|
||||||
|
|
||||||
# 构造一个 2×2 的卷积核,全为 1,用于检测局部连续墙壁的概率之和
|
# 构造一个 2×2 的卷积核,全为 1,用于检测局部连续墙壁的概率之和
|
||||||
kernel = torch.ones((1, 1, 2, 2), device=logits.device)
|
kernel = torch.ones((1, 1, 2, 2), device=pred.device)
|
||||||
|
|
||||||
# 对内部区域进行卷积操作,计算每个 2×2 区域内的概率和
|
# 对内部区域进行卷积操作,计算每个 2×2 区域内的概率和
|
||||||
# 需要将 interior 扩展一个通道维度
|
# 需要将 interior 扩展一个通道维度
|
||||||
@ -63,14 +63,14 @@ def internal_wall_loss(logits, probs, wall_class=1, threshold=2.5):
|
|||||||
loss = penalty.mean()
|
loss = penalty.mean()
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
def entrance_loss(logits, probs, stairs_class=10, arrow_class=11):
|
def entrance_loss(pred, stairs_class=10, arrow_class=11):
|
||||||
"""
|
"""
|
||||||
针对地图生成的额外约束损失:
|
针对地图生成的额外约束损失:
|
||||||
- 保证最外圈不出现楼梯类型入口(数字10)
|
- 保证最外圈不出现楼梯类型入口(数字10)
|
||||||
- 保证内部区域不出现箭头类型入口(数字11)
|
- 保证内部区域不出现箭头类型入口(数字11)
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
logits: 模型输出,形状 [B, C, H, W]
|
pred: 模型输出,形状 [B, C, H, W]
|
||||||
stairs_class: 楼梯入口对应的类别(数字10)
|
stairs_class: 楼梯入口对应的类别(数字10)
|
||||||
arrow_class: 箭头入口对应的类别(数字11)
|
arrow_class: 箭头入口对应的类别(数字11)
|
||||||
|
|
||||||
@ -78,10 +78,10 @@ def entrance_loss(logits, probs, stairs_class=10, arrow_class=11):
|
|||||||
loss: 针对入口出现的惩罚损失
|
loss: 针对入口出现的惩罚损失
|
||||||
"""
|
"""
|
||||||
# 先将 logits 转为概率分布
|
# 先将 logits 转为概率分布
|
||||||
B, C, H, W = logits.shape
|
B, C, H, W = pred.shape
|
||||||
|
|
||||||
# 构造最外圈 mask:外圈为 True,其余为 False
|
# 构造最外圈 mask:外圈为 True,其余为 False
|
||||||
outer_mask = torch.zeros((H, W), dtype=torch.bool, device=logits.device)
|
outer_mask = torch.zeros((H, W), dtype=torch.bool, device=pred.device)
|
||||||
outer_mask[0, :] = True
|
outer_mask[0, :] = True
|
||||||
outer_mask[-1, :] = True
|
outer_mask[-1, :] = True
|
||||||
outer_mask[:, 0] = True
|
outer_mask[:, 0] = True
|
||||||
@ -91,8 +91,8 @@ def entrance_loss(logits, probs, stairs_class=10, arrow_class=11):
|
|||||||
interior_mask = ~outer_mask # 取反
|
interior_mask = ~outer_mask # 取反
|
||||||
|
|
||||||
# 提取对应类别的概率图
|
# 提取对应类别的概率图
|
||||||
stairs_probs = probs[:, stairs_class, :, :] # 楼梯概率 [B, H, W]
|
stairs_probs = pred[:, stairs_class, :, :] # 楼梯概率 [B, H, W]
|
||||||
arrow_probs = probs[:, arrow_class, :, :] # 箭头概率 [B, H, W]
|
arrow_probs = pred[:, arrow_class, :, :] # 箭头概率 [B, H, W]
|
||||||
|
|
||||||
# 从最外圈提取楼梯概率;用 mask 索引时:张量[:, mask] 会将每个样本的外圈像素展平
|
# 从最外圈提取楼梯概率;用 mask 索引时:张量[:, mask] 会将每个样本的外圈像素展平
|
||||||
outer_stairs = stairs_probs[:, outer_mask] # [B, num_outer_pixels]
|
outer_stairs = stairs_probs[:, outer_mask] # [B, num_outer_pixels]
|
||||||
@ -107,7 +107,7 @@ def entrance_loss(logits, probs, stairs_class=10, arrow_class=11):
|
|||||||
return total_loss
|
return total_loss
|
||||||
|
|
||||||
def entrance_distance_and_presence_loss(
|
def entrance_distance_and_presence_loss(
|
||||||
logits, probs,
|
pred,
|
||||||
arrow_class=11, stairs_class=10,
|
arrow_class=11, stairs_class=10,
|
||||||
arrow_min_threshold=0.5, stairs_min_threshold=0.5,
|
arrow_min_threshold=0.5, stairs_min_threshold=0.5,
|
||||||
lambda_arrow_presence=1.0, lambda_stairs_presence=1.0
|
lambda_arrow_presence=1.0, lambda_stairs_presence=1.0
|
||||||
@ -121,7 +121,7 @@ def entrance_distance_and_presence_loss(
|
|||||||
楼梯入口要求在一个窗口(地图尺寸一半)内只出现一个楼梯入口。
|
楼梯入口要求在一个窗口(地图尺寸一半)内只出现一个楼梯入口。
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
logits: 模型输出, shape [B, C, H, W]
|
pred: 模型输出, shape [B, C, H, W]
|
||||||
arrow_class: 箭头入口类别(默认 11)
|
arrow_class: 箭头入口类别(默认 11)
|
||||||
stairs_class: 楼梯入口类别(默认 10)
|
stairs_class: 楼梯入口类别(默认 10)
|
||||||
arrow_min_threshold: 箭头入口全局最小平均概率要求(可根据任务调节)
|
arrow_min_threshold: 箭头入口全局最小平均概率要求(可根据任务调节)
|
||||||
@ -132,15 +132,15 @@ def entrance_distance_and_presence_loss(
|
|||||||
total_loss: 综合入口距离与存在性损失
|
total_loss: 综合入口距离与存在性损失
|
||||||
"""
|
"""
|
||||||
# 将 logits 转换为概率分布
|
# 将 logits 转换为概率分布
|
||||||
B, C, H, W = logits.shape
|
B, C, H, W = pred.shape
|
||||||
|
|
||||||
# 提取箭头和楼梯的概率图
|
# 提取箭头和楼梯的概率图
|
||||||
arrow_probs = probs[:, arrow_class, :, :] # [B, H, W]
|
arrow_probs = pred[:, arrow_class, :, :] # [B, H, W]
|
||||||
stairs_probs = probs[:, stairs_class, :, :] # [B, H, W]
|
stairs_probs = pred[:, stairs_class, :, :] # [B, H, W]
|
||||||
|
|
||||||
#### 局部距离约束 ####
|
#### 局部距离约束 ####
|
||||||
# 箭头:构造 9x9 卷积核,半径 4
|
# 箭头:构造 9x9 卷积核,半径 4
|
||||||
kernel_arrow = torch.ones((1, 1, 9, 9), device=logits.device)
|
kernel_arrow = torch.ones((1, 1, 9, 9), device=pred.device)
|
||||||
local_arrow_sum = F.conv2d(arrow_probs.unsqueeze(1), kernel_arrow, padding=4)
|
local_arrow_sum = F.conv2d(arrow_probs.unsqueeze(1), kernel_arrow, padding=4)
|
||||||
# 减去自身概率,计算多余的局部累积
|
# 减去自身概率,计算多余的局部累积
|
||||||
arrow_excess = local_arrow_sum - arrow_probs.unsqueeze(1)
|
arrow_excess = local_arrow_sum - arrow_probs.unsqueeze(1)
|
||||||
@ -148,7 +148,7 @@ def entrance_distance_and_presence_loss(
|
|||||||
|
|
||||||
# 楼梯:使用窗口大小为 (W//2, H//2)
|
# 楼梯:使用窗口大小为 (W//2, H//2)
|
||||||
kernel_size_stairs = (9, 9)
|
kernel_size_stairs = (9, 9)
|
||||||
kernel_stairs = torch.ones((1, 1, kernel_size_stairs[0], kernel_size_stairs[1]), device=logits.device)
|
kernel_stairs = torch.ones((1, 1, kernel_size_stairs[0], kernel_size_stairs[1]), device=pred.device)
|
||||||
pad_stairs = ((kernel_size_stairs[0] - 1) // 2, (kernel_size_stairs[1] - 1) // 2)
|
pad_stairs = ((kernel_size_stairs[0] - 1) // 2, (kernel_size_stairs[1] - 1) // 2)
|
||||||
local_stairs_sum = F.conv2d(stairs_probs.unsqueeze(1), kernel_stairs, padding=pad_stairs)
|
local_stairs_sum = F.conv2d(stairs_probs.unsqueeze(1), kernel_stairs, padding=pad_stairs)
|
||||||
stairs_excess = local_stairs_sum - stairs_probs.unsqueeze(1)
|
stairs_excess = local_stairs_sum - stairs_probs.unsqueeze(1)
|
||||||
@ -175,12 +175,12 @@ def entrance_distance_and_presence_loss(
|
|||||||
+ min(ap_weighted, sp_weighted)
|
+ min(ap_weighted, sp_weighted)
|
||||||
return total_loss
|
return total_loss
|
||||||
|
|
||||||
def monster_consecutive_loss(logits, probs, monster_classes=[7,8,9], threshold=2.9):
|
def monster_consecutive_loss(pred, monster_classes=[7,8,9], threshold=2.9):
|
||||||
"""
|
"""
|
||||||
检查横向和纵向是否存在连续超过三个的怪物(类别 7,8,9)。
|
检查横向和纵向是否存在连续超过三个的怪物(类别 7,8,9)。
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
logits: 模型输出,形状 [B, C, H, W]
|
pred: 模型输出,形状 [B, C, H, W]
|
||||||
monster_classes: 待检测的怪物类别列表
|
monster_classes: 待检测的怪物类别列表
|
||||||
threshold: 滑动窗口内概率和的阈值,若超过则施加惩罚
|
threshold: 滑动窗口内概率和的阈值,若超过则施加惩罚
|
||||||
(对于连续三个像素,如果每个像素概率接近 1,则窗口和接近 3)
|
(对于连续三个像素,如果每个像素概率接近 1,则窗口和接近 3)
|
||||||
@ -189,23 +189,23 @@ def monster_consecutive_loss(logits, probs, monster_classes=[7,8,9], threshold=2
|
|||||||
loss: 惩罚损失(数值越高表示连续怪物区域越严重)
|
loss: 惩罚损失(数值越高表示连续怪物区域越严重)
|
||||||
"""
|
"""
|
||||||
# 将 logits 转换为概率分布
|
# 将 logits 转换为概率分布
|
||||||
B, C, H, W = logits.shape
|
B, C, H, W = pred.shape
|
||||||
|
|
||||||
# 得到怪物整体概率图:将类别 7,8,9 的概率相加
|
# 得到怪物整体概率图:将类别 7,8,9 的概率相加
|
||||||
monster_probs = probs[:, monster_classes, :].sum(dim=1) # [B, H, W]
|
monster_probs = pred[:, monster_classes, :].sum(dim=1) # [B, H, W]
|
||||||
|
|
||||||
# 注意:monster_probs 越高说明该像素更有可能是怪物
|
# 注意:monster_probs 越高说明该像素更有可能是怪物
|
||||||
|
|
||||||
# --- 横向检测 ---
|
# --- 横向检测 ---
|
||||||
# 构造一个 (1,3) 的卷积核,全 1
|
# 构造一个 (1,3) 的卷积核,全 1
|
||||||
kernel_horiz = torch.ones((1, 1, 1, 3), device=logits.device)
|
kernel_horiz = torch.ones((1, 1, 1, 3), device=pred.device)
|
||||||
# 对 monster_probs 加一个 channel 维度,使形状为 [B, 1, H, W]
|
# 对 monster_probs 加一个 channel 维度,使形状为 [B, 1, H, W]
|
||||||
conv_horiz = F.conv2d(monster_probs.unsqueeze(1), kernel_horiz, padding=(0,1))
|
conv_horiz = F.conv2d(monster_probs.unsqueeze(1), kernel_horiz, padding=(0,1))
|
||||||
# conv_horiz 的每个值表示相邻三个像素的怪物概率和
|
# conv_horiz 的每个值表示相邻三个像素的怪物概率和
|
||||||
|
|
||||||
# --- 纵向检测 ---
|
# --- 纵向检测 ---
|
||||||
# 构造一个 (3,1) 的卷积核,全 1
|
# 构造一个 (3,1) 的卷积核,全 1
|
||||||
kernel_vert = torch.ones((1, 1, 3, 1), device=logits.device)
|
kernel_vert = torch.ones((1, 1, 3, 1), device=pred.device)
|
||||||
conv_vert = F.conv2d(monster_probs.unsqueeze(1), kernel_vert, padding=(1,0))
|
conv_vert = F.conv2d(monster_probs.unsqueeze(1), kernel_vert, padding=(1,0))
|
||||||
# conv_vert 的每个值表示垂直连续三个像素的怪物概率和
|
# conv_vert 的每个值表示垂直连续三个像素的怪物概率和
|
||||||
|
|
||||||
@ -217,31 +217,32 @@ def monster_consecutive_loss(logits, probs, monster_classes=[7,8,9], threshold=2
|
|||||||
loss = penalty_horiz.mean() + penalty_vert.mean()
|
loss = penalty_horiz.mean() + penalty_vert.mean()
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
def illegal_block_loss(logits ,probs, used_classes=12, mode='mean'):
|
def illegal_block_loss(pred, used_classes=12, mode='mean'):
|
||||||
"""
|
"""
|
||||||
对未使用类别(例如 12 ~ 31)的预测概率施加惩罚,
|
对未使用类别(例如 12 ~ 31)的预测概率施加惩罚,
|
||||||
鼓励模型输出仅集中在 0 ~ 11 上。
|
鼓励模型输出仅集中在 0 ~ 11 上。
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
logits: 模型输出,形状 [B, num_classes, H, W]
|
pred: 模型输出,形状 [B, num_classes, H, W]
|
||||||
used_classes: 已经使用的类别数(例如 12 表示只使用 0-11)
|
used_classes: 已经使用的类别数(例如 12 表示只使用 0-11)
|
||||||
mode: 'mean' 使用平均概率,或 'mse' 使用均方误差
|
mode: 'mean' 使用平均概率,或 'mse' 使用均方误差
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
penalty: 标量惩罚损失
|
penalty: 标量惩罚损失
|
||||||
"""
|
"""
|
||||||
|
B, C, H, W = pred.shape
|
||||||
# 选取非法类别的概率(注意:这一步会得到非法图块在每个像素上的概率)
|
# 选取非法类别的概率(注意:这一步会得到非法图块在每个像素上的概率)
|
||||||
illegal_probs = probs[:, range(used_classes, 32), :, :] # [B, len(illegal_classes), H, W]
|
illegal_probs = pred[:, range(used_classes, 32), :, :] # [B, len(illegal_classes), H, W]
|
||||||
|
|
||||||
# 我们可以将非法图块的概率在类别维度上求和,得到每个像素的非法激活值
|
# 我们可以将非法图块的概率在类别维度上求和,得到每个像素的非法激活值
|
||||||
illegal_activation = illegal_probs.sum(dim=1) # [B, H, W]
|
illegal_activation = illegal_probs.sum(dim=1) # [B, H, W]
|
||||||
|
|
||||||
# 接下来我们计算整个图上非法激活的“数量”
|
# 接下来我们计算整个图上非法激活的“数量”
|
||||||
# 例如,可以直接对整个 batch 内非法激活求和
|
# 例如,可以直接对整个 batch 内非法激活求和
|
||||||
total_illegal = illegal_activation.sum() # 标量
|
total_illegal = illegal_activation.sum() / B # 标量
|
||||||
|
|
||||||
# 计算损失值:使用负指数函数。注意如果非法激活很小,总损失接近 exp(0)=1
|
# 计算损失值:使用负指数函数。注意如果非法激活很小,总损失接近 exp(0)=1
|
||||||
loss = torch.sqrt(total_illegal)
|
loss = torch.sqrt(total_illegal).mean()
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
def integrated_count_loss(probs, target, class_list=[0,1,2,3,4,5,6,7,8,9], tolerance=0.5):
|
def integrated_count_loss(probs, target, class_list=[0,1,2,3,4,5,6,7,8,9], tolerance=0.5):
|
||||||
@ -283,7 +284,7 @@ def integrated_count_loss(probs, target, class_list=[0,1,2,3,4,5,6,7,8,9], toler
|
|||||||
return avg_loss
|
return avg_loss
|
||||||
|
|
||||||
class GinkaLoss(nn.Module):
|
class GinkaLoss(nn.Module):
|
||||||
def __init__(self, minamo: MinamoModel, converter: DynamicGraphConverter, weight=[0.35, 0.1, 0.1, 0.1, 0.1, 0.05, 0.1, 0.1]):
|
def __init__(self, minamo: MinamoModel, weight=[0.35, 0.1, 0.1, 0.1, 0.1, 0.05, 0.1, 0.1]):
|
||||||
"""Ginka Model 损失函数部分
|
"""Ginka Model 损失函数部分
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -299,41 +300,37 @@ class GinkaLoss(nn.Module):
|
|||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.weight = weight
|
self.weight = weight
|
||||||
self.ce = nn.CrossEntropyLoss()
|
|
||||||
self.minamo = minamo
|
self.minamo = minamo
|
||||||
self.tau = 1
|
|
||||||
self.converter = converter
|
|
||||||
|
|
||||||
def forward(self, pred, pred_softmax, target, target_vision_feat, target_topo_feat):
|
def forward(self, pred, target, target_vision_feat, target_topo_feat):
|
||||||
probs = F.softmax(pred, dim=1)
|
|
||||||
# 地图结构损失
|
# 地图结构损失
|
||||||
border_loss = wall_border_loss(pred, probs)
|
border_loss = wall_border_loss(pred)
|
||||||
wall_loss = internal_wall_loss(pred, probs)
|
wall_loss = internal_wall_loss(pred)
|
||||||
entry_loss = entrance_loss(pred, probs)
|
entry_loss = entrance_loss(pred)
|
||||||
entry_dis_loss = entrance_distance_and_presence_loss(pred, probs)
|
entry_dis_loss = entrance_distance_and_presence_loss(pred, )
|
||||||
enemy_loss = monster_consecutive_loss(pred, probs)
|
enemy_loss = monster_consecutive_loss(pred)
|
||||||
valid_block_loss = illegal_block_loss(pred, probs, used_classes=12, mode="mean")
|
valid_block_loss = illegal_block_loss(pred, used_classes=12, mode="mean")
|
||||||
count_loss = integrated_count_loss(probs, target)
|
count_loss = integrated_count_loss(pred, target)
|
||||||
|
|
||||||
# 使用 Minamo Model 计算相似度
|
# 使用 Minamo Model 计算相似度
|
||||||
graph = self.converter(pred, tau=self.tau)
|
graph = convert_soft_map_to_graph(pred)
|
||||||
pred_vision_feat, pred_topo_feat = self.minamo(pred_softmax, graph)
|
pred_vision_feat, pred_topo_feat = self.minamo(pred, graph)
|
||||||
|
|
||||||
vision_sim = F.cosine_similarity(pred_vision_feat, target_vision_feat, dim=-1)
|
vision_sim = F.cosine_similarity(pred_vision_feat, target_vision_feat, dim=-1)
|
||||||
topo_sim = F.cosine_similarity(pred_topo_feat, target_topo_feat, dim=-1)
|
topo_sim = F.cosine_similarity(pred_topo_feat, target_topo_feat, dim=-1)
|
||||||
minamo_sim = 0.3 * vision_sim + 0.7 * topo_sim
|
minamo_sim = 0.3 * vision_sim + 0.7 * topo_sim
|
||||||
minamo_loss = torch.exp(-1 * (minamo_sim - 0.8)).mean()
|
minamo_loss = torch.exp(-1 * (minamo_sim - 0.8)).mean()
|
||||||
|
|
||||||
# print(
|
print(
|
||||||
# minamo_loss.item(),
|
minamo_loss.item(),
|
||||||
# border_loss.item(),
|
border_loss.item(),
|
||||||
# wall_loss.item(),
|
wall_loss.item(),
|
||||||
# entry_loss.item(),
|
entry_loss.item(),
|
||||||
# entry_dis_loss.item(),
|
entry_dis_loss.item(),
|
||||||
# enemy_loss.item(),
|
enemy_loss.item(),
|
||||||
# valid_block_loss.item(),
|
valid_block_loss.item(),
|
||||||
# count_loss.item()
|
count_loss.item()
|
||||||
# )
|
)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
minamo_loss * self.weight[0] +
|
minamo_loss * self.weight[0] +
|
||||||
|
|||||||
@ -3,20 +3,6 @@ import torch.nn as nn
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from .unet import GinkaUNet
|
from .unet import GinkaUNet
|
||||||
|
|
||||||
class GumbelSoftmax(nn.Module):
|
|
||||||
def __init__(self, tau=1.0, hard=True):
|
|
||||||
super().__init__()
|
|
||||||
self.tau = tau # 温度参数
|
|
||||||
self.hard = hard # 是否生成硬性one-hot
|
|
||||||
|
|
||||||
def forward(self, logits):
|
|
||||||
# logits形状: [BS, C, H, W]
|
|
||||||
y = F.gumbel_softmax(logits, tau=self.tau, hard=self.hard)
|
|
||||||
|
|
||||||
# 转换为类索引的连续表示
|
|
||||||
# class_indices = torch.arange(y.size(1), device=y.device).view(1, -1, 1, 1)
|
|
||||||
return y.argmax(dim=1) # 形状[BS, H, W]
|
|
||||||
|
|
||||||
class GinkaModel(nn.Module):
|
class GinkaModel(nn.Module):
|
||||||
def __init__(self, feat_dim=256, base_ch=64, num_classes=32):
|
def __init__(self, feat_dim=256, base_ch=64, num_classes=32):
|
||||||
"""Ginka Model 模型定义部分
|
"""Ginka Model 模型定义部分
|
||||||
@ -27,7 +13,6 @@ class GinkaModel(nn.Module):
|
|||||||
nn.Linear(feat_dim, 32 * 32 * base_ch)
|
nn.Linear(feat_dim, 32 * 32 * base_ch)
|
||||||
)
|
)
|
||||||
self.unet = GinkaUNet(base_ch, num_classes)
|
self.unet = GinkaUNet(base_ch, num_classes)
|
||||||
self.softmax = GumbelSoftmax()
|
|
||||||
|
|
||||||
def forward(self, feat):
|
def forward(self, feat):
|
||||||
"""
|
"""
|
||||||
@ -40,5 +25,5 @@ class GinkaModel(nn.Module):
|
|||||||
x = x.view(-1, self.base_ch, 32, 32)
|
x = x.view(-1, self.base_ch, 32, 32)
|
||||||
x = self.unet(x)
|
x = self.unet(x)
|
||||||
x = F.interpolate(x, (13, 13), mode='bilinear', align_corners=False)
|
x = F.interpolate(x, (13, 13), mode='bilinear', align_corners=False)
|
||||||
return x, self.softmax(x)
|
return F.softmax(x)
|
||||||
|
|
||||||
@ -8,7 +8,6 @@ from .model.model import GinkaModel
|
|||||||
from .model.loss import GinkaLoss
|
from .model.loss import GinkaLoss
|
||||||
from .dataset import GinkaDataset
|
from .dataset import GinkaDataset
|
||||||
from minamo.model.model import MinamoModel
|
from minamo.model.model import MinamoModel
|
||||||
from shared.graph import DynamicGraphConverter
|
|
||||||
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
os.makedirs("result", exist_ok=True)
|
os.makedirs("result", exist_ok=True)
|
||||||
@ -22,6 +21,10 @@ def update_tau(epoch):
|
|||||||
decay_rate = 0.95
|
decay_rate = 0.95
|
||||||
return max(min_tau, start_tau * (decay_rate ** epoch))
|
return max(min_tau, start_tau * (decay_rate ** epoch))
|
||||||
|
|
||||||
|
# 在生成器输出后添加梯度检查钩子
|
||||||
|
def grad_hook(module, grad_input, grad_output):
|
||||||
|
print(f"Generator output grad norm: {grad_output[0].norm().item()}")
|
||||||
|
|
||||||
def train():
|
def train():
|
||||||
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
|
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
|
||||||
model = GinkaModel()
|
model = GinkaModel()
|
||||||
@ -31,10 +34,8 @@ def train():
|
|||||||
minamo.to(device)
|
minamo.to(device)
|
||||||
minamo.eval()
|
minamo.eval()
|
||||||
|
|
||||||
for param in minamo.parameters():
|
# for param in minamo.parameters():
|
||||||
param.requires_grad = False
|
# param.requires_grad = False
|
||||||
|
|
||||||
converter = DynamicGraphConverter().to(device)
|
|
||||||
|
|
||||||
# 准备数据集
|
# 准备数据集
|
||||||
dataset = GinkaDataset("ginka-dataset.json", device, minamo)
|
dataset = GinkaDataset("ginka-dataset.json", device, minamo)
|
||||||
@ -53,14 +54,16 @@ def train():
|
|||||||
# 设定优化器与调度器
|
# 设定优化器与调度器
|
||||||
optimizer = optim.AdamW(model.parameters(), lr=3e-4)
|
optimizer = optim.AdamW(model.parameters(), lr=3e-4)
|
||||||
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
|
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
|
||||||
criterion = GinkaLoss(minamo, converter)
|
criterion = GinkaLoss(minamo, weight=[1, 0, 0, 0, 0, 0, 0, 0])
|
||||||
|
|
||||||
|
model.register_full_backward_hook(grad_hook)
|
||||||
|
# converter.register_full_backward_hook(grad_hook)
|
||||||
|
criterion.register_full_backward_hook(grad_hook)
|
||||||
|
|
||||||
# 开始训练
|
# 开始训练
|
||||||
for epoch in tqdm(range(epochs)):
|
for epoch in tqdm(range(epochs)):
|
||||||
model.train()
|
model.train()
|
||||||
total_loss = 0
|
total_loss = 0
|
||||||
model.softmax.tau = update_tau(epoch)
|
|
||||||
criterion.tau = update_tau(epoch)
|
|
||||||
|
|
||||||
for batch in dataloader:
|
for batch in dataloader:
|
||||||
# 数据迁移到设备
|
# 数据迁移到设备
|
||||||
@ -81,7 +84,7 @@ def train():
|
|||||||
total_loss += loss.item()
|
total_loss += loss.item()
|
||||||
|
|
||||||
avg_loss = total_loss / len(dataloader)
|
avg_loss = total_loss / len(dataloader)
|
||||||
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch} | loss: {avg_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}")
|
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {avg_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}")
|
||||||
|
|
||||||
# total_norm = 0
|
# total_norm = 0
|
||||||
# for p in model.parameters():
|
# for p in model.parameters():
|
||||||
|
|||||||
55
ginka/validate.py
Normal file
55
ginka/validate.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch_geometric.loader import DataLoader
|
||||||
|
from tqdm import tqdm
|
||||||
|
from minamo.model.model import MinamoModel
|
||||||
|
from .dataset import GinkaDataset
|
||||||
|
from .model.loss import GinkaLoss
|
||||||
|
from .model.model import GinkaModel
|
||||||
|
from shared.graph import DynamicGraphConverter
|
||||||
|
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
def validate():
|
||||||
|
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.")
|
||||||
|
model = GinkaModel()
|
||||||
|
|
||||||
|
minamo = MinamoModel(32)
|
||||||
|
minamo.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"])
|
||||||
|
minamo.to(device)
|
||||||
|
|
||||||
|
# 准备数据集
|
||||||
|
val_dataset = GinkaDataset("ginka-eval.json")
|
||||||
|
val_loader = DataLoader(
|
||||||
|
val_dataset,
|
||||||
|
batch_size=32,
|
||||||
|
shuffle=True
|
||||||
|
)
|
||||||
|
|
||||||
|
converter = DynamicGraphConverter().to(device)
|
||||||
|
criterion = GinkaLoss(minamo, converter)
|
||||||
|
|
||||||
|
minamo.eval()
|
||||||
|
val_loss = 0
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch in val_loader:
|
||||||
|
# 数据迁移到设备
|
||||||
|
target = batch["target"].to(device)
|
||||||
|
target_vision_feat = batch["target_vision_feat"].to(device)
|
||||||
|
target_topo_feat = batch["target_topo_feat"].to(device)
|
||||||
|
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device)
|
||||||
|
# 前向传播
|
||||||
|
output, _ = model(feat_vec)
|
||||||
|
map_matrix = torch.argmax(output, dim=1)
|
||||||
|
|
||||||
|
# 计算损失
|
||||||
|
loss = criterion(output, map_matrix, target, target_vision_feat, target_topo_feat)
|
||||||
|
total_loss += loss.item()
|
||||||
|
|
||||||
|
avg_val_loss = val_loss / len(val_loader)
|
||||||
|
tqdm.write(f"Validation::loss: {avg_val_loss:.6f}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
torch.set_num_threads(2)
|
||||||
|
validate()
|
||||||
|
|
||||||
@ -122,7 +122,8 @@ def train():
|
|||||||
graph1 = graph1.to(device)
|
graph1 = graph1.to(device)
|
||||||
graph2 = graph2.to(device)
|
graph2 = graph2.to(device)
|
||||||
|
|
||||||
vision_feat1, vision_feat2, topo_feat1, topo_feat2 = model(map1_val, map2_val, graph1, graph2)
|
vision_feat1, topo_feat1 = model(map1, graph1)
|
||||||
|
vision_feat2, topo_feat2 = model(map2, graph2)
|
||||||
|
|
||||||
vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
|
vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
|
||||||
topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
|
topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
|
||||||
|
|||||||
@ -1,9 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
from torch_geometric.data import Data
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch_geometric.data import Data, Batch
|
|
||||||
|
|
||||||
def convert_soft_map_to_graph(map_probs):
|
def convert_soft_map_to_graph(map_probs: torch.Tensor):
|
||||||
"""
|
"""
|
||||||
直接使用 Softmax 概率构建 soft 图结构
|
直接使用 Softmax 概率构建 soft 图结构
|
||||||
"""
|
"""
|
||||||
@ -33,87 +31,3 @@ def convert_soft_map_to_graph(map_probs):
|
|||||||
map_probs[:, edge_index[1] // W, edge_index[1] % W]) / 2
|
map_probs[:, edge_index[1] // W, edge_index[1] % W]) / 2
|
||||||
|
|
||||||
return Data(x=node_features, edge_index=edge_index, edge_attr=soft_edge_weight)
|
return Data(x=node_features, edge_index=edge_index, edge_attr=soft_edge_weight)
|
||||||
|
|
||||||
def convert_map_to_graph(map):
|
|
||||||
rows = len(map)
|
|
||||||
cols = len(map[0])
|
|
||||||
node_indices = {}
|
|
||||||
valid_nodes = []
|
|
||||||
node_counter = 0
|
|
||||||
|
|
||||||
for r in range(rows):
|
|
||||||
for c in range(cols):
|
|
||||||
if map[r][c] != 1: # 排除墙体
|
|
||||||
node_indices[(r, c)] = node_counter
|
|
||||||
valid_nodes.append((r, c, map[r][c])) # (行, 列, 地形类型)
|
|
||||||
node_counter += 1
|
|
||||||
|
|
||||||
edge_list = []
|
|
||||||
for (r, c, _) in valid_nodes:
|
|
||||||
node = node_indices[(r, c)]
|
|
||||||
if c + 1 < cols and (r, c + 1) in node_indices:
|
|
||||||
edge_list.append((node, node_indices[(r, c + 1)]))
|
|
||||||
if r + 1 < rows and (r + 1, c) in node_indices:
|
|
||||||
edge_list.append((node, node_indices[(r + 1, c)]))
|
|
||||||
|
|
||||||
edge_index = torch.tensor(edge_list, dtype=torch.long).T
|
|
||||||
node_features = torch.tensor([node_type for (_, _, node_type) in valid_nodes], dtype=torch.long)
|
|
||||||
|
|
||||||
return Data(x=node_features, edge_index=edge_index)
|
|
||||||
|
|
||||||
class DynamicGraphConverter(nn.Module):
|
|
||||||
def __init__(self, map_size=13):
|
|
||||||
super().__init__()
|
|
||||||
self.map_size = map_size
|
|
||||||
self.n_nodes = map_size * map_size
|
|
||||||
self.base_edge_index = self._precompute_base_edges()
|
|
||||||
|
|
||||||
def _precompute_base_edges(self):
|
|
||||||
edge_list = []
|
|
||||||
directions = [(0, 1), (1, 0)]
|
|
||||||
for r in range(self.map_size):
|
|
||||||
for c in range(self.map_size):
|
|
||||||
node = r * self.map_size + c
|
|
||||||
for dr, dc in directions:
|
|
||||||
nr, nc = r + dr, c + dc
|
|
||||||
if 0 <= nr < self.map_size and 0 <= nc < self.map_size:
|
|
||||||
neighbor = nr * self.map_size + nc
|
|
||||||
edge_list.append([node, neighbor])
|
|
||||||
return torch.tensor(edge_list).t().contiguous().unique(dim=1)
|
|
||||||
|
|
||||||
def forward(self, map_probs, tau=0.5):
|
|
||||||
B, C, H, W = map_probs.shape
|
|
||||||
device = map_probs.device
|
|
||||||
self.base_edge_index = self.base_edge_index.to(device)
|
|
||||||
|
|
||||||
# 1. 计算可微的节点 ID
|
|
||||||
node_logits = map_probs.view(B, C, -1).permute(0, 2, 1) # [B, N, C]
|
|
||||||
hard_nodes = F.gumbel_softmax(node_logits, tau=tau, hard=True)
|
|
||||||
node_ids = (hard_nodes * torch.arange(C, device=device).view(1, 1, -1)).sum(dim=-1).long()
|
|
||||||
|
|
||||||
# 2. 计算 soft 壁障 mask
|
|
||||||
wall_mask = torch.sigmoid((node_ids - 1) * 10) # 类别 1 代表墙体,soft 处理
|
|
||||||
edge_weights = self._compute_dynamic_weights(wall_mask)
|
|
||||||
|
|
||||||
# 3. 构建动态图
|
|
||||||
batch_data = []
|
|
||||||
for b in range(B):
|
|
||||||
soft_mask = torch.sigmoid((edge_weights[b] - 0.1) * 10) # 软门控
|
|
||||||
dynamic_edge_attr = edge_weights[b] * soft_mask # 仍然保留梯度
|
|
||||||
|
|
||||||
data = Data(
|
|
||||||
x=node_ids[b],
|
|
||||||
edge_index=self.base_edge_index,
|
|
||||||
edge_attr=dynamic_edge_attr
|
|
||||||
)
|
|
||||||
batch_data.append(data)
|
|
||||||
|
|
||||||
return Batch.from_data_list(batch_data)
|
|
||||||
|
|
||||||
def _compute_dynamic_weights(self, wall_mask):
|
|
||||||
src_nodes = self.base_edge_index[0]
|
|
||||||
dst_nodes = self.base_edge_index[1]
|
|
||||||
|
|
||||||
# 让梯度能正确回传
|
|
||||||
weights = 1 - (wall_mask[:, src_nodes] + wall_mask[:, dst_nodes]) / 2
|
|
||||||
return weights.unsqueeze(-1)
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user