refactor: GINKA Model 改为 softmax 输出

This commit is contained in:
unanmed 2025-03-19 21:24:29 +08:00
parent cb9e67dff7
commit fd72b1e7f4
7 changed files with 129 additions and 173 deletions

View File

@ -1,8 +1,9 @@
import json import json
import torch import torch
import torch.nn.functional as F
from torch.utils.data import Dataset from torch.utils.data import Dataset
from minamo.model.model import MinamoModel from minamo.model.model import MinamoModel
from shared.graph import convert_map_to_graph from shared.graph import convert_soft_map_to_graph
def load_data(path: str): def load_data(path: str):
with open(path, 'r', encoding="utf-8") as f: with open(path, 'r', encoding="utf-8") as f:
@ -27,8 +28,8 @@ class GinkaDataset(Dataset):
def __getitem__(self, idx): def __getitem__(self, idx):
item = self.data[idx] item = self.data[idx]
target = torch.tensor(item["map"]).to(self.device) target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float().to(self.device) # [32, H, W]
graph = convert_map_to_graph(target).to(self.device) graph = convert_soft_map_to_graph(target).to(self.device)
vision_feat, topo_feat = self.minamo(target.unsqueeze(0), graph) vision_feat, topo_feat = self.minamo(target.unsqueeze(0), graph)
return { return {

View File

@ -3,9 +3,9 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from minamo.model.model import MinamoModel from minamo.model.model import MinamoModel
from shared.graph import DynamicGraphConverter from shared.graph import convert_soft_map_to_graph
def wall_border_loss(pred: torch.Tensor, probs: torch.Tensor, allow_border=[1, 11]): def wall_border_loss(pred: torch.Tensor, allow_border=[1, 11]):
"""地图最外层是否为墙""" """地图最外层是否为墙"""
# 计算 softmax 概率 # 计算 softmax 概率
B, C, H, W = pred.shape B, C, H, W = pred.shape
@ -18,7 +18,7 @@ def wall_border_loss(pred: torch.Tensor, probs: torch.Tensor, allow_border=[1, 1
border_mask[:, -1] = True border_mask[:, -1] = True
# 对允许的类别求概率和(即该像素为允许类别的总概率) # 对允许的类别求概率和(即该像素为允许类别的总概率)
allowed_prob = probs[:, allow_border, :, :].sum(dim=1) # [B, H, W] allowed_prob = pred[:, allow_border, :, :].sum(dim=1) # [B, H, W]
# 只计算边界区域的损失:对于边界上的每个像素,要求 allowed_prob 越高越好 # 只计算边界区域的损失:对于边界上的每个像素,要求 allowed_prob 越高越好
border_allowed_prob = allowed_prob[:, border_mask] # [B, N_border_pixels] border_allowed_prob = allowed_prob[:, border_mask] # [B, N_border_pixels]
@ -28,13 +28,13 @@ def wall_border_loss(pred: torch.Tensor, probs: torch.Tensor, allow_border=[1, 1
return loss return loss
def internal_wall_loss(logits, probs, wall_class=1, threshold=2.5): def internal_wall_loss(pred, wall_class=1, threshold=2.5):
""" """
针对内部区域排除最外圈设计的损失函数 针对内部区域排除最外圈设计的损失函数
当内部任意 2×2 区域的 wall 类别概率之和超过阈值时施加惩罚 当内部任意 2×2 区域的 wall 类别概率之和超过阈值时施加惩罚
参数: 参数:
logits: 模型输出形状 [B, C, H, W] pred: 模型输出形状 [B, C, H, W]
wall_class: 对应墙壁的类别索引这里假设墙壁数字为1 wall_class: 对应墙壁的类别索引这里假设墙壁数字为1
threshold: 2×2 区域概率之和的阈值超过此值时施加惩罚可根据实际情况调节 threshold: 2×2 区域概率之和的阈值超过此值时施加惩罚可根据实际情况调节
@ -42,13 +42,13 @@ def internal_wall_loss(logits, probs, wall_class=1, threshold=2.5):
loss: 内部墙壁连续区域的平均惩罚损失 loss: 内部墙壁连续区域的平均惩罚损失
""" """
# 取出对应墙壁类别的概率图 [B, H, W] # 取出对应墙壁类别的概率图 [B, H, W]
wall_probs = probs[:, wall_class, :, :] wall_probs = pred[:, wall_class, :, :]
# 排除最外圈,取内部区域 (H, W 均减去2) # 排除最外圈,取内部区域 (H, W 均减去2)
interior = wall_probs[:, 1:-1, 1:-1] # [B, H-2, W-2] interior = wall_probs[:, 1:-1, 1:-1] # [B, H-2, W-2]
# 构造一个 2×2 的卷积核,全为 1用于检测局部连续墙壁的概率之和 # 构造一个 2×2 的卷积核,全为 1用于检测局部连续墙壁的概率之和
kernel = torch.ones((1, 1, 2, 2), device=logits.device) kernel = torch.ones((1, 1, 2, 2), device=pred.device)
# 对内部区域进行卷积操作,计算每个 2×2 区域内的概率和 # 对内部区域进行卷积操作,计算每个 2×2 区域内的概率和
# 需要将 interior 扩展一个通道维度 # 需要将 interior 扩展一个通道维度
@ -63,14 +63,14 @@ def internal_wall_loss(logits, probs, wall_class=1, threshold=2.5):
loss = penalty.mean() loss = penalty.mean()
return loss return loss
def entrance_loss(logits, probs, stairs_class=10, arrow_class=11): def entrance_loss(pred, stairs_class=10, arrow_class=11):
""" """
针对地图生成的额外约束损失 针对地图生成的额外约束损失
- 保证最外圈不出现楼梯类型入口数字10 - 保证最外圈不出现楼梯类型入口数字10
- 保证内部区域不出现箭头类型入口数字11 - 保证内部区域不出现箭头类型入口数字11
参数: 参数:
logits: 模型输出形状 [B, C, H, W] pred: 模型输出形状 [B, C, H, W]
stairs_class: 楼梯入口对应的类别数字10 stairs_class: 楼梯入口对应的类别数字10
arrow_class: 箭头入口对应的类别数字11 arrow_class: 箭头入口对应的类别数字11
@ -78,10 +78,10 @@ def entrance_loss(logits, probs, stairs_class=10, arrow_class=11):
loss: 针对入口出现的惩罚损失 loss: 针对入口出现的惩罚损失
""" """
# 先将 logits 转为概率分布 # 先将 logits 转为概率分布
B, C, H, W = logits.shape B, C, H, W = pred.shape
# 构造最外圈 mask外圈为 True其余为 False # 构造最外圈 mask外圈为 True其余为 False
outer_mask = torch.zeros((H, W), dtype=torch.bool, device=logits.device) outer_mask = torch.zeros((H, W), dtype=torch.bool, device=pred.device)
outer_mask[0, :] = True outer_mask[0, :] = True
outer_mask[-1, :] = True outer_mask[-1, :] = True
outer_mask[:, 0] = True outer_mask[:, 0] = True
@ -91,8 +91,8 @@ def entrance_loss(logits, probs, stairs_class=10, arrow_class=11):
interior_mask = ~outer_mask # 取反 interior_mask = ~outer_mask # 取反
# 提取对应类别的概率图 # 提取对应类别的概率图
stairs_probs = probs[:, stairs_class, :, :] # 楼梯概率 [B, H, W] stairs_probs = pred[:, stairs_class, :, :] # 楼梯概率 [B, H, W]
arrow_probs = probs[:, arrow_class, :, :] # 箭头概率 [B, H, W] arrow_probs = pred[:, arrow_class, :, :] # 箭头概率 [B, H, W]
# 从最外圈提取楼梯概率;用 mask 索引时:张量[:, mask] 会将每个样本的外圈像素展平 # 从最外圈提取楼梯概率;用 mask 索引时:张量[:, mask] 会将每个样本的外圈像素展平
outer_stairs = stairs_probs[:, outer_mask] # [B, num_outer_pixels] outer_stairs = stairs_probs[:, outer_mask] # [B, num_outer_pixels]
@ -107,7 +107,7 @@ def entrance_loss(logits, probs, stairs_class=10, arrow_class=11):
return total_loss return total_loss
def entrance_distance_and_presence_loss( def entrance_distance_and_presence_loss(
logits, probs, pred,
arrow_class=11, stairs_class=10, arrow_class=11, stairs_class=10,
arrow_min_threshold=0.5, stairs_min_threshold=0.5, arrow_min_threshold=0.5, stairs_min_threshold=0.5,
lambda_arrow_presence=1.0, lambda_stairs_presence=1.0 lambda_arrow_presence=1.0, lambda_stairs_presence=1.0
@ -121,7 +121,7 @@ def entrance_distance_and_presence_loss(
楼梯入口要求在一个窗口地图尺寸一半内只出现一个楼梯入口 楼梯入口要求在一个窗口地图尺寸一半内只出现一个楼梯入口
参数: 参数:
logits: 模型输出, shape [B, C, H, W] pred: 模型输出, shape [B, C, H, W]
arrow_class: 箭头入口类别默认 11 arrow_class: 箭头入口类别默认 11
stairs_class: 楼梯入口类别默认 10 stairs_class: 楼梯入口类别默认 10
arrow_min_threshold: 箭头入口全局最小平均概率要求可根据任务调节 arrow_min_threshold: 箭头入口全局最小平均概率要求可根据任务调节
@ -132,15 +132,15 @@ def entrance_distance_and_presence_loss(
total_loss: 综合入口距离与存在性损失 total_loss: 综合入口距离与存在性损失
""" """
# 将 logits 转换为概率分布 # 将 logits 转换为概率分布
B, C, H, W = logits.shape B, C, H, W = pred.shape
# 提取箭头和楼梯的概率图 # 提取箭头和楼梯的概率图
arrow_probs = probs[:, arrow_class, :, :] # [B, H, W] arrow_probs = pred[:, arrow_class, :, :] # [B, H, W]
stairs_probs = probs[:, stairs_class, :, :] # [B, H, W] stairs_probs = pred[:, stairs_class, :, :] # [B, H, W]
#### 局部距离约束 #### #### 局部距离约束 ####
# 箭头:构造 9x9 卷积核,半径 4 # 箭头:构造 9x9 卷积核,半径 4
kernel_arrow = torch.ones((1, 1, 9, 9), device=logits.device) kernel_arrow = torch.ones((1, 1, 9, 9), device=pred.device)
local_arrow_sum = F.conv2d(arrow_probs.unsqueeze(1), kernel_arrow, padding=4) local_arrow_sum = F.conv2d(arrow_probs.unsqueeze(1), kernel_arrow, padding=4)
# 减去自身概率,计算多余的局部累积 # 减去自身概率,计算多余的局部累积
arrow_excess = local_arrow_sum - arrow_probs.unsqueeze(1) arrow_excess = local_arrow_sum - arrow_probs.unsqueeze(1)
@ -148,7 +148,7 @@ def entrance_distance_and_presence_loss(
# 楼梯:使用窗口大小为 (W//2, H//2) # 楼梯:使用窗口大小为 (W//2, H//2)
kernel_size_stairs = (9, 9) kernel_size_stairs = (9, 9)
kernel_stairs = torch.ones((1, 1, kernel_size_stairs[0], kernel_size_stairs[1]), device=logits.device) kernel_stairs = torch.ones((1, 1, kernel_size_stairs[0], kernel_size_stairs[1]), device=pred.device)
pad_stairs = ((kernel_size_stairs[0] - 1) // 2, (kernel_size_stairs[1] - 1) // 2) pad_stairs = ((kernel_size_stairs[0] - 1) // 2, (kernel_size_stairs[1] - 1) // 2)
local_stairs_sum = F.conv2d(stairs_probs.unsqueeze(1), kernel_stairs, padding=pad_stairs) local_stairs_sum = F.conv2d(stairs_probs.unsqueeze(1), kernel_stairs, padding=pad_stairs)
stairs_excess = local_stairs_sum - stairs_probs.unsqueeze(1) stairs_excess = local_stairs_sum - stairs_probs.unsqueeze(1)
@ -175,12 +175,12 @@ def entrance_distance_and_presence_loss(
+ min(ap_weighted, sp_weighted) + min(ap_weighted, sp_weighted)
return total_loss return total_loss
def monster_consecutive_loss(logits, probs, monster_classes=[7,8,9], threshold=2.9): def monster_consecutive_loss(pred, monster_classes=[7,8,9], threshold=2.9):
""" """
检查横向和纵向是否存在连续超过三个的怪物类别 7,8,9 检查横向和纵向是否存在连续超过三个的怪物类别 7,8,9
参数: 参数:
logits: 模型输出形状 [B, C, H, W] pred: 模型输出形状 [B, C, H, W]
monster_classes: 待检测的怪物类别列表 monster_classes: 待检测的怪物类别列表
threshold: 滑动窗口内概率和的阈值若超过则施加惩罚 threshold: 滑动窗口内概率和的阈值若超过则施加惩罚
对于连续三个像素如果每个像素概率接近 1则窗口和接近 3 对于连续三个像素如果每个像素概率接近 1则窗口和接近 3
@ -189,23 +189,23 @@ def monster_consecutive_loss(logits, probs, monster_classes=[7,8,9], threshold=2
loss: 惩罚损失数值越高表示连续怪物区域越严重 loss: 惩罚损失数值越高表示连续怪物区域越严重
""" """
# 将 logits 转换为概率分布 # 将 logits 转换为概率分布
B, C, H, W = logits.shape B, C, H, W = pred.shape
# 得到怪物整体概率图:将类别 7,8,9 的概率相加 # 得到怪物整体概率图:将类别 7,8,9 的概率相加
monster_probs = probs[:, monster_classes, :].sum(dim=1) # [B, H, W] monster_probs = pred[:, monster_classes, :].sum(dim=1) # [B, H, W]
# 注意monster_probs 越高说明该像素更有可能是怪物 # 注意monster_probs 越高说明该像素更有可能是怪物
# --- 横向检测 --- # --- 横向检测 ---
# 构造一个 (1,3) 的卷积核,全 1 # 构造一个 (1,3) 的卷积核,全 1
kernel_horiz = torch.ones((1, 1, 1, 3), device=logits.device) kernel_horiz = torch.ones((1, 1, 1, 3), device=pred.device)
# 对 monster_probs 加一个 channel 维度,使形状为 [B, 1, H, W] # 对 monster_probs 加一个 channel 维度,使形状为 [B, 1, H, W]
conv_horiz = F.conv2d(monster_probs.unsqueeze(1), kernel_horiz, padding=(0,1)) conv_horiz = F.conv2d(monster_probs.unsqueeze(1), kernel_horiz, padding=(0,1))
# conv_horiz 的每个值表示相邻三个像素的怪物概率和 # conv_horiz 的每个值表示相邻三个像素的怪物概率和
# --- 纵向检测 --- # --- 纵向检测 ---
# 构造一个 (3,1) 的卷积核,全 1 # 构造一个 (3,1) 的卷积核,全 1
kernel_vert = torch.ones((1, 1, 3, 1), device=logits.device) kernel_vert = torch.ones((1, 1, 3, 1), device=pred.device)
conv_vert = F.conv2d(monster_probs.unsqueeze(1), kernel_vert, padding=(1,0)) conv_vert = F.conv2d(monster_probs.unsqueeze(1), kernel_vert, padding=(1,0))
# conv_vert 的每个值表示垂直连续三个像素的怪物概率和 # conv_vert 的每个值表示垂直连续三个像素的怪物概率和
@ -217,31 +217,32 @@ def monster_consecutive_loss(logits, probs, monster_classes=[7,8,9], threshold=2
loss = penalty_horiz.mean() + penalty_vert.mean() loss = penalty_horiz.mean() + penalty_vert.mean()
return loss return loss
def illegal_block_loss(logits ,probs, used_classes=12, mode='mean'): def illegal_block_loss(pred, used_classes=12, mode='mean'):
""" """
对未使用类别例如 12 ~ 31的预测概率施加惩罚 对未使用类别例如 12 ~ 31的预测概率施加惩罚
鼓励模型输出仅集中在 0 ~ 11 鼓励模型输出仅集中在 0 ~ 11
参数: 参数:
logits: 模型输出形状 [B, num_classes, H, W] pred: 模型输出形状 [B, num_classes, H, W]
used_classes: 已经使用的类别数例如 12 表示只使用 0-11 used_classes: 已经使用的类别数例如 12 表示只使用 0-11
mode: 'mean' 使用平均概率 'mse' 使用均方误差 mode: 'mean' 使用平均概率 'mse' 使用均方误差
返回: 返回:
penalty: 标量惩罚损失 penalty: 标量惩罚损失
""" """
B, C, H, W = pred.shape
# 选取非法类别的概率(注意:这一步会得到非法图块在每个像素上的概率) # 选取非法类别的概率(注意:这一步会得到非法图块在每个像素上的概率)
illegal_probs = probs[:, range(used_classes, 32), :, :] # [B, len(illegal_classes), H, W] illegal_probs = pred[:, range(used_classes, 32), :, :] # [B, len(illegal_classes), H, W]
# 我们可以将非法图块的概率在类别维度上求和,得到每个像素的非法激活值 # 我们可以将非法图块的概率在类别维度上求和,得到每个像素的非法激活值
illegal_activation = illegal_probs.sum(dim=1) # [B, H, W] illegal_activation = illegal_probs.sum(dim=1) # [B, H, W]
# 接下来我们计算整个图上非法激活的“数量” # 接下来我们计算整个图上非法激活的“数量”
# 例如,可以直接对整个 batch 内非法激活求和 # 例如,可以直接对整个 batch 内非法激活求和
total_illegal = illegal_activation.sum() # 标量 total_illegal = illegal_activation.sum() / B # 标量
# 计算损失值:使用负指数函数。注意如果非法激活很小,总损失接近 exp(0)=1 # 计算损失值:使用负指数函数。注意如果非法激活很小,总损失接近 exp(0)=1
loss = torch.sqrt(total_illegal) loss = torch.sqrt(total_illegal).mean()
return loss return loss
def integrated_count_loss(probs, target, class_list=[0,1,2,3,4,5,6,7,8,9], tolerance=0.5): def integrated_count_loss(probs, target, class_list=[0,1,2,3,4,5,6,7,8,9], tolerance=0.5):
@ -283,7 +284,7 @@ def integrated_count_loss(probs, target, class_list=[0,1,2,3,4,5,6,7,8,9], toler
return avg_loss return avg_loss
class GinkaLoss(nn.Module): class GinkaLoss(nn.Module):
def __init__(self, minamo: MinamoModel, converter: DynamicGraphConverter, weight=[0.35, 0.1, 0.1, 0.1, 0.1, 0.05, 0.1, 0.1]): def __init__(self, minamo: MinamoModel, weight=[0.35, 0.1, 0.1, 0.1, 0.1, 0.05, 0.1, 0.1]):
"""Ginka Model 损失函数部分 """Ginka Model 损失函数部分
Args: Args:
@ -299,41 +300,37 @@ class GinkaLoss(nn.Module):
""" """
super().__init__() super().__init__()
self.weight = weight self.weight = weight
self.ce = nn.CrossEntropyLoss()
self.minamo = minamo self.minamo = minamo
self.tau = 1
self.converter = converter
def forward(self, pred, pred_softmax, target, target_vision_feat, target_topo_feat): def forward(self, pred, target, target_vision_feat, target_topo_feat):
probs = F.softmax(pred, dim=1)
# 地图结构损失 # 地图结构损失
border_loss = wall_border_loss(pred, probs) border_loss = wall_border_loss(pred)
wall_loss = internal_wall_loss(pred, probs) wall_loss = internal_wall_loss(pred)
entry_loss = entrance_loss(pred, probs) entry_loss = entrance_loss(pred)
entry_dis_loss = entrance_distance_and_presence_loss(pred, probs) entry_dis_loss = entrance_distance_and_presence_loss(pred, )
enemy_loss = monster_consecutive_loss(pred, probs) enemy_loss = monster_consecutive_loss(pred)
valid_block_loss = illegal_block_loss(pred, probs, used_classes=12, mode="mean") valid_block_loss = illegal_block_loss(pred, used_classes=12, mode="mean")
count_loss = integrated_count_loss(probs, target) count_loss = integrated_count_loss(pred, target)
# 使用 Minamo Model 计算相似度 # 使用 Minamo Model 计算相似度
graph = self.converter(pred, tau=self.tau) graph = convert_soft_map_to_graph(pred)
pred_vision_feat, pred_topo_feat = self.minamo(pred_softmax, graph) pred_vision_feat, pred_topo_feat = self.minamo(pred, graph)
vision_sim = F.cosine_similarity(pred_vision_feat, target_vision_feat, dim=-1) vision_sim = F.cosine_similarity(pred_vision_feat, target_vision_feat, dim=-1)
topo_sim = F.cosine_similarity(pred_topo_feat, target_topo_feat, dim=-1) topo_sim = F.cosine_similarity(pred_topo_feat, target_topo_feat, dim=-1)
minamo_sim = 0.3 * vision_sim + 0.7 * topo_sim minamo_sim = 0.3 * vision_sim + 0.7 * topo_sim
minamo_loss = torch.exp(-1 * (minamo_sim - 0.8)).mean() minamo_loss = torch.exp(-1 * (minamo_sim - 0.8)).mean()
# print( print(
# minamo_loss.item(), minamo_loss.item(),
# border_loss.item(), border_loss.item(),
# wall_loss.item(), wall_loss.item(),
# entry_loss.item(), entry_loss.item(),
# entry_dis_loss.item(), entry_dis_loss.item(),
# enemy_loss.item(), enemy_loss.item(),
# valid_block_loss.item(), valid_block_loss.item(),
# count_loss.item() count_loss.item()
# ) )
return ( return (
minamo_loss * self.weight[0] + minamo_loss * self.weight[0] +

View File

@ -3,20 +3,6 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from .unet import GinkaUNet from .unet import GinkaUNet
class GumbelSoftmax(nn.Module):
def __init__(self, tau=1.0, hard=True):
super().__init__()
self.tau = tau # 温度参数
self.hard = hard # 是否生成硬性one-hot
def forward(self, logits):
# logits形状: [BS, C, H, W]
y = F.gumbel_softmax(logits, tau=self.tau, hard=self.hard)
# 转换为类索引的连续表示
# class_indices = torch.arange(y.size(1), device=y.device).view(1, -1, 1, 1)
return y.argmax(dim=1) # 形状[BS, H, W]
class GinkaModel(nn.Module): class GinkaModel(nn.Module):
def __init__(self, feat_dim=256, base_ch=64, num_classes=32): def __init__(self, feat_dim=256, base_ch=64, num_classes=32):
"""Ginka Model 模型定义部分 """Ginka Model 模型定义部分
@ -27,7 +13,6 @@ class GinkaModel(nn.Module):
nn.Linear(feat_dim, 32 * 32 * base_ch) nn.Linear(feat_dim, 32 * 32 * base_ch)
) )
self.unet = GinkaUNet(base_ch, num_classes) self.unet = GinkaUNet(base_ch, num_classes)
self.softmax = GumbelSoftmax()
def forward(self, feat): def forward(self, feat):
""" """
@ -40,5 +25,5 @@ class GinkaModel(nn.Module):
x = x.view(-1, self.base_ch, 32, 32) x = x.view(-1, self.base_ch, 32, 32)
x = self.unet(x) x = self.unet(x)
x = F.interpolate(x, (13, 13), mode='bilinear', align_corners=False) x = F.interpolate(x, (13, 13), mode='bilinear', align_corners=False)
return x, self.softmax(x) return F.softmax(x)

View File

@ -8,7 +8,6 @@ from .model.model import GinkaModel
from .model.loss import GinkaLoss from .model.loss import GinkaLoss
from .dataset import GinkaDataset from .dataset import GinkaDataset
from minamo.model.model import MinamoModel from minamo.model.model import MinamoModel
from shared.graph import DynamicGraphConverter
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs("result", exist_ok=True) os.makedirs("result", exist_ok=True)
@ -22,6 +21,10 @@ def update_tau(epoch):
decay_rate = 0.95 decay_rate = 0.95
return max(min_tau, start_tau * (decay_rate ** epoch)) return max(min_tau, start_tau * (decay_rate ** epoch))
# 在生成器输出后添加梯度检查钩子
def grad_hook(module, grad_input, grad_output):
print(f"Generator output grad norm: {grad_output[0].norm().item()}")
def train(): def train():
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.") print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
model = GinkaModel() model = GinkaModel()
@ -31,10 +34,8 @@ def train():
minamo.to(device) minamo.to(device)
minamo.eval() minamo.eval()
for param in minamo.parameters(): # for param in minamo.parameters():
param.requires_grad = False # param.requires_grad = False
converter = DynamicGraphConverter().to(device)
# 准备数据集 # 准备数据集
dataset = GinkaDataset("ginka-dataset.json", device, minamo) dataset = GinkaDataset("ginka-dataset.json", device, minamo)
@ -53,14 +54,16 @@ def train():
# 设定优化器与调度器 # 设定优化器与调度器
optimizer = optim.AdamW(model.parameters(), lr=3e-4) optimizer = optim.AdamW(model.parameters(), lr=3e-4)
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6) scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
criterion = GinkaLoss(minamo, converter) criterion = GinkaLoss(minamo, weight=[1, 0, 0, 0, 0, 0, 0, 0])
model.register_full_backward_hook(grad_hook)
# converter.register_full_backward_hook(grad_hook)
criterion.register_full_backward_hook(grad_hook)
# 开始训练 # 开始训练
for epoch in tqdm(range(epochs)): for epoch in tqdm(range(epochs)):
model.train() model.train()
total_loss = 0 total_loss = 0
model.softmax.tau = update_tau(epoch)
criterion.tau = update_tau(epoch)
for batch in dataloader: for batch in dataloader:
# 数据迁移到设备 # 数据迁移到设备
@ -81,7 +84,7 @@ def train():
total_loss += loss.item() total_loss += loss.item()
avg_loss = total_loss / len(dataloader) avg_loss = total_loss / len(dataloader)
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch} | loss: {avg_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}") tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {avg_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}")
# total_norm = 0 # total_norm = 0
# for p in model.parameters(): # for p in model.parameters():

55
ginka/validate.py Normal file
View File

@ -0,0 +1,55 @@
import torch
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from tqdm import tqdm
from minamo.model.model import MinamoModel
from .dataset import GinkaDataset
from .model.loss import GinkaLoss
from .model.model import GinkaModel
from shared.graph import DynamicGraphConverter
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def validate():
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.")
model = GinkaModel()
minamo = MinamoModel(32)
minamo.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"])
minamo.to(device)
# 准备数据集
val_dataset = GinkaDataset("ginka-eval.json")
val_loader = DataLoader(
val_dataset,
batch_size=32,
shuffle=True
)
converter = DynamicGraphConverter().to(device)
criterion = GinkaLoss(minamo, converter)
minamo.eval()
val_loss = 0
with torch.no_grad():
for batch in val_loader:
# 数据迁移到设备
target = batch["target"].to(device)
target_vision_feat = batch["target_vision_feat"].to(device)
target_topo_feat = batch["target_topo_feat"].to(device)
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device)
# 前向传播
output, _ = model(feat_vec)
map_matrix = torch.argmax(output, dim=1)
# 计算损失
loss = criterion(output, map_matrix, target, target_vision_feat, target_topo_feat)
total_loss += loss.item()
avg_val_loss = val_loss / len(val_loader)
tqdm.write(f"Validation::loss: {avg_val_loss:.6f}")
if __name__ == "__main__":
torch.set_num_threads(2)
validate()

View File

@ -122,7 +122,8 @@ def train():
graph1 = graph1.to(device) graph1 = graph1.to(device)
graph2 = graph2.to(device) graph2 = graph2.to(device)
vision_feat1, vision_feat2, topo_feat1, topo_feat2 = model(map1_val, map2_val, graph1, graph2) vision_feat1, topo_feat1 = model(map1, graph1)
vision_feat2, topo_feat2 = model(map2, graph2)
vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1) vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1) topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)

View File

@ -1,9 +1,7 @@
import torch import torch
import torch.nn as nn from torch_geometric.data import Data
import torch.nn.functional as F
from torch_geometric.data import Data, Batch
def convert_soft_map_to_graph(map_probs): def convert_soft_map_to_graph(map_probs: torch.Tensor):
""" """
直接使用 Softmax 概率构建 soft 图结构 直接使用 Softmax 概率构建 soft 图结构
""" """
@ -33,87 +31,3 @@ def convert_soft_map_to_graph(map_probs):
map_probs[:, edge_index[1] // W, edge_index[1] % W]) / 2 map_probs[:, edge_index[1] // W, edge_index[1] % W]) / 2
return Data(x=node_features, edge_index=edge_index, edge_attr=soft_edge_weight) return Data(x=node_features, edge_index=edge_index, edge_attr=soft_edge_weight)
def convert_map_to_graph(map):
rows = len(map)
cols = len(map[0])
node_indices = {}
valid_nodes = []
node_counter = 0
for r in range(rows):
for c in range(cols):
if map[r][c] != 1: # 排除墙体
node_indices[(r, c)] = node_counter
valid_nodes.append((r, c, map[r][c])) # (行, 列, 地形类型)
node_counter += 1
edge_list = []
for (r, c, _) in valid_nodes:
node = node_indices[(r, c)]
if c + 1 < cols and (r, c + 1) in node_indices:
edge_list.append((node, node_indices[(r, c + 1)]))
if r + 1 < rows and (r + 1, c) in node_indices:
edge_list.append((node, node_indices[(r + 1, c)]))
edge_index = torch.tensor(edge_list, dtype=torch.long).T
node_features = torch.tensor([node_type for (_, _, node_type) in valid_nodes], dtype=torch.long)
return Data(x=node_features, edge_index=edge_index)
class DynamicGraphConverter(nn.Module):
def __init__(self, map_size=13):
super().__init__()
self.map_size = map_size
self.n_nodes = map_size * map_size
self.base_edge_index = self._precompute_base_edges()
def _precompute_base_edges(self):
edge_list = []
directions = [(0, 1), (1, 0)]
for r in range(self.map_size):
for c in range(self.map_size):
node = r * self.map_size + c
for dr, dc in directions:
nr, nc = r + dr, c + dc
if 0 <= nr < self.map_size and 0 <= nc < self.map_size:
neighbor = nr * self.map_size + nc
edge_list.append([node, neighbor])
return torch.tensor(edge_list).t().contiguous().unique(dim=1)
def forward(self, map_probs, tau=0.5):
B, C, H, W = map_probs.shape
device = map_probs.device
self.base_edge_index = self.base_edge_index.to(device)
# 1. 计算可微的节点 ID
node_logits = map_probs.view(B, C, -1).permute(0, 2, 1) # [B, N, C]
hard_nodes = F.gumbel_softmax(node_logits, tau=tau, hard=True)
node_ids = (hard_nodes * torch.arange(C, device=device).view(1, 1, -1)).sum(dim=-1).long()
# 2. 计算 soft 壁障 mask
wall_mask = torch.sigmoid((node_ids - 1) * 10) # 类别 1 代表墙体soft 处理
edge_weights = self._compute_dynamic_weights(wall_mask)
# 3. 构建动态图
batch_data = []
for b in range(B):
soft_mask = torch.sigmoid((edge_weights[b] - 0.1) * 10) # 软门控
dynamic_edge_attr = edge_weights[b] * soft_mask # 仍然保留梯度
data = Data(
x=node_ids[b],
edge_index=self.base_edge_index,
edge_attr=dynamic_edge_attr
)
batch_data.append(data)
return Batch.from_data_list(batch_data)
def _compute_dynamic_weights(self, wall_mask):
src_nodes = self.base_edge_index[0]
dst_nodes = self.base_edge_index[1]
# 让梯度能正确回传
weights = 1 - (wall_mask[:, src_nodes] + wall_mask[:, dst_nodes]) / 2
return weights.unsqueeze(-1)