refactor: GINKA generator uses Minamo as part of the loss

unanmed 2025-03-18 23:51:14 +08:00
parent 1566acf691
commit 09c63fedce
6 changed files with 221 additions and 61 deletions

View File

@@ -15,10 +15,11 @@ def load_data(path: str):
return data_list
class GinkaDataset(Dataset):
def __init__(self, data_path: str, minamo: MinamoModel):
def __init__(self, data_path: str, device, minamo: MinamoModel):
self.data = load_data(data_path) # custom data-loading helper
self.max_size = 32
self.minamo = minamo
self.device = device
def __len__(self):
return len(self.data)
@@ -26,13 +27,13 @@ class GinkaDataset(Dataset):
def __getitem__(self, idx):
item = self.data[idx]
target = torch.tensor(item["map"])
graph = convert_map_to_graph(target)
vision_feat, topo_feat = self.minamo(target, graph)
feat_vec = torch.cat([vision_feat, topo_feat])
target = torch.tensor(item["map"]).to(self.device)
graph = convert_map_to_graph(target).to(self.device)
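# Precompute the target map's Minamo features once per sample; unsqueeze(0) adds the batch dimension for Minamo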
vision_feat, topo_feat = self.minamo(target.unsqueeze(0), graph)
return {
"feat_vec": feat_vec,
"target_vision_feat": vision_feat,
"target_topo_feat": topo_feat,
"target": target
}

View File

@@ -3,6 +3,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from minamo.model.model import MinamoModel
from shared.graph import DynamicGraphConverter
def wall_border_loss(pred: torch.Tensor, probs: torch.Tensor, allow_border=[1, 11]):
"""地图最外层是否为墙"""
@@ -131,7 +132,6 @@ def entrance_distance_and_presence_loss(
total_loss: combined entrance distance and presence loss
"""
# Convert logits to a probability distribution
probs = F.softmax(logits, dim=1) # [B, C, H, W]
B, C, H, W = logits.shape
# Extract probability maps for arrows and stairs
@@ -147,9 +147,9 @@
arrow_distance_loss = F.relu(arrow_excess).mean()
# Stairs: use a window of size (W//2, H//2)
kernel_size_stairs = (max(1, W // 2), max(1, H // 2))
kernel_size_stairs = (9, 9)
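# Sum stair probability over a fixed 9x9 window; any mass beyond the center cell itself is penalized below, discouraging stairs placed close together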
kernel_stairs = torch.ones((1, 1, kernel_size_stairs[0], kernel_size_stairs[1]), device=logits.device)
pad_stairs = (kernel_size_stairs[0] // 2, kernel_size_stairs[1] // 2)
pad_stairs = ((kernel_size_stairs[0] - 1) // 2, (kernel_size_stairs[1] - 1) // 2)
local_stairs_sum = F.conv2d(stairs_probs.unsqueeze(1), kernel_stairs, padding=pad_stairs)
stairs_excess = local_stairs_sum - stairs_probs.unsqueeze(1)
stairs_distance_loss = F.relu(stairs_excess).mean()
@@ -283,7 +283,7 @@ def integrated_count_loss(probs, target, class_list=[0,1,2,3,4,5,6,7,8,9], toler
return avg_loss
class GinkaLoss(nn.Module):
def __init__(self, minamo: MinamoModel, weight=[0.35, 0.1, 0.1, 0.1, 0.1, 0.05, 0.1, 0.1]):
def __init__(self, minamo: MinamoModel, converter: DynamicGraphConverter, weight=[0.35, 0.1, 0.1, 0.1, 0.1, 0.05, 0.1, 0.1]):
"""Ginka Model 损失函数部分
Args:
@@ -301,8 +301,10 @@ class GinkaLoss(nn.Module):
self.weight = weight
self.ce = nn.CrossEntropyLoss()
self.minamo = minamo
self.tau = 1
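# Gumbel-Softmax temperature used by the graph converter; the training loop anneals it each epoch via update_tau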
self.converter = converter
def forward(self, pred, pred_softmax, target):
def forward(self, pred, pred_softmax, target, target_vision_feat, target_topo_feat):
probs = F.softmax(pred, dim=1)
# Map structure losses
border_loss = wall_border_loss(pred, probs)
@@ -314,21 +316,27 @@
count_loss = integrated_count_loss(probs, target)
# Compute similarity with the Minamo Model
graph = self.converter(pred, tau=self.tau)
pred_vision_feat, pred_topo_feat = self.minamo(pred_softmax, graph)
vision_sim = F.cosine_similarity(pred_vision_feat, target_vision_feat, dim=-1)
topo_sim = F.cosine_similarity(pred_topo_feat, target_topo_feat, dim=-1)
minamo_sim = 0.3 * vision_sim + 0.7 * topo_sim
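# exp(-10 * (sim - 0.8)) maps similarity to a positive loss: about 1 at sim = 0.8, rising sharply as similarity drops below that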
minamo_loss = torch.exp(-10 * (minamo_sim - 0.8)).mean()
print(
# structure_loss.item(),
border_loss.item(),
wall_loss.item(),
entry_loss.item(),
entry_dis_loss.item(),
enemy_loss.item(),
valid_block_loss.item(),
count_loss.item()
)
# print(
# minamo_loss.item(),
# border_loss.item(),
# wall_loss.item(),
# entry_loss.item(),
# entry_dis_loss.item(),
# enemy_loss.item(),
# valid_block_loss.item(),
# count_loss.item()
# )
return (
# structure_loss * self.weight[0] +
minamo_loss * self.weight[0] +
border_loss * self.weight[1] +
wall_loss * self.weight[2] +
entry_loss * self.weight[3] +

View File

@@ -14,8 +14,8 @@ class GumbelSoftmax(nn.Module):
y = F.gumbel_softmax(logits, tau=self.tau, hard=self.hard)
# Convert to a continuous representation of class indices
class_indices = torch.arange(y.size(1), device=y.device).view(1, -1, 1, 1)
return (y * class_indices).sum(dim=1) # shape: [BS, H, W]
# class_indices = torch.arange(y.size(1), device=y.device).view(1, -1, 1, 1)
return y.argmax(dim=1) # shape: [BS, H, W]
class GinkaModel(nn.Module):
def __init__(self, feat_dim=256, base_ch=64, num_classes=32):

View File

@@ -16,33 +16,25 @@ class GinkaEncoder(nn.Module):
self.pool = nn.MaxPool2d(2)
def forward(self, x):
x_res = self.conv(x) # extract features via convolution
x_down = self.pool(x_res) # apply pooling
return x_down, x_res # return the pooled features and the skip-connection features
x_res = self.conv(x)
x_down = self.pool(x_res)
return x_down, x_res
class GinkaDecoder(nn.Module):
"""解码器(上采样)部分"""
def __init__(self, in_channels, out_channels):
super().__init__()
# Upsampling (bilinear interpolation + convolution)
self.upsample = nn.Sequential(
nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU()
)
# Skip-connection fusion
self.upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
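# A learned transposed convolution (kernel 2, stride 2) doubles the spatial resolution, replacing the bilinear upsample + conv block above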
self.conv = nn.Sequential(
nn.Conv2d(out_channels * 2, out_channels, kernel_size=3, padding=1),
nn.Conv2d(in_channels + out_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU()
)
def forward(self, x, skip):
x = self.upsample(x)
# Skip-connection fusion
x = torch.cat([x, skip], dim=1)
x = torch.cat([x, skip], dim=1)
x = self.conv(x)
return x
@@ -66,26 +58,26 @@ class GinkaUNet(nn.Module):
"""Ginka Model UNet 部分
"""
super().__init__()
self.down1 = GinkaEncoder(in_ch, in_ch*2)
self.down2 = GinkaEncoder(in_ch*2, in_ch*4)
self.bottleneck = GinkaBottleneck(in_ch*4, in_ch*4)
self.up1 = GinkaDecoder(in_ch*4, in_ch*2)
self.up2 = GinkaDecoder(in_ch*2, in_ch)
self.final = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 1)
nn.Conv2d(in_ch, out_ch, 1),
# nn.Softmax(dim=1) # suitable for classification tasks
)
def forward(self, x):
x, skip1 = self.down1(x)
x, skip2 = self.down2(x)
x = self.bottleneck(x)
x = self.up1(x, skip2)
x = self.up2(x, skip1)
x_down1, skip1 = self.down1(x)
x_down2, skip2 = self.down2(x_down1)
x = self.bottleneck(x_down2)
x = self.up1(x, skip2) # uses the skip from down2
x = self.up2(x, skip1) # uses the skip from down1
return self.final(x)

View File

@@ -2,17 +2,17 @@ import os
from datetime import datetime
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import BertTokenizer
from tqdm import tqdm
from .model.model import GinkaModel
from .model.loss import GinkaLoss
from .dataset import GinkaDataset
from minamo.model.model import MinamoModel
from shared.graph import DynamicGraphConverter
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs("result", exist_ok=True)
os.makedirs("result/ginka_checkpoint", exist_ok=True)
epochs = 70
@@ -29,48 +29,85 @@ def train():
minamo = MinamoModel(32)
minamo.to(device)
minamo.eval()
converter = DynamicGraphConverter().to(device)
# Prepare the datasets
dataset = GinkaDataset("dataset.json", minamo)
dataset = GinkaDataset("ginka-dataset.json", device, minamo)
dataset_val = GinkaDataset("ginka-eval.json", device, minamo)
dataloader = DataLoader(
dataset,
batch_size=32,
shuffle=True
)
dataloader_val = DataLoader(
dataset_val,
batch_size=32,
shuffle=True
)
# Set up the optimizer and scheduler
optimizer = optim.AdamW(model.parameters(), lr=3e-4)
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
criterion = GinkaLoss(minamo)
criterion = GinkaLoss(minamo, converter)
# Start training
for epoch in tqdm(range(epochs)):
model.train()
total_loss = 0
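# Anneal the Gumbel-Softmax temperature for both the model's sampling head and the loss's graph converter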
model.softmax.tau = update_tau(epoch)
criterion.tau = update_tau(epoch)
for batch in dataloader:
# Move data to the device
target = batch["target"].to(device)
feat_vec = batch["feat_vec"].to(device)
target_vision_feat = batch["target_vision_feat"].to(device)
target_topo_feat = batch["target_topo_feat"].to(device)
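# Condition the generator on the target's concatenated Minamo vision and topology features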
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device)
# Forward pass
optimizer.zero_grad()
output, output_softmax = model(feat_vec)
# Compute the loss
loss = criterion(output, output_softmax, target)
loss = criterion(output, output_softmax, target, target_vision_feat, target_topo_feat)
# Backward pass
loss.backward()
optimizer.step()
total_loss += loss.item()
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch} | loss: {total_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}")
avg_loss = total_loss / len(dataloader)
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch} | loss: {avg_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}")
# Adjust the learning rate
scheduler.step()
if (epoch + 1) % 5 == 0:
loss_val = 0
model.eval()
with torch.no_grad():
for batch in dataloader_val:
# Move data to the device
target = batch["target"].to(device)
target_vision_feat = batch["target_vision_feat"].to(device)
target_topo_feat = batch["target_topo_feat"].to(device)
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device)
# Forward pass
output, output_softmax = model(feat_vec)
# Compute the loss
loss = criterion(output, output_softmax, target, target_vision_feat, target_topo_feat)
loss_val += loss.item()
avg_val_loss = loss_val / len(dataloader_val)
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}")
torch.save({
"model_state": model.state_dict(),
"optimizer_state": optimizer.state_dict(),
}, f"result/ginka_checkpoint/{epoch + 1}.pth")
print("Train ended.")
torch.save({

View File

@@ -1,5 +1,8 @@
import torch
from torch_geometric.data import Data
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, Batch
from torch_geometric.utils import dense_to_sparse
def convert_map_to_graph(map):
rows = len(map)
@@ -26,4 +29,123 @@ def convert_map_to_graph(map):
edge_index = torch.tensor(edge_list, dtype=torch.long).T
node_features = torch.tensor([node_type for (_, _, node_type) in valid_nodes], dtype=torch.long)
return Data(x=node_features, edge_index=edge_index)
return Data(x=node_features, edge_index=edge_index)
def soft_convert_map_to_graph(map_tensor, tau=0.5, threshold=0.1):
"""
Convert a batch of maps into Graph Data usable by a GNN (using a soft strategy).
"""
B, H, W, C = map_tensor.shape
N = H * W
# Use Gumbel-Softmax to ensure a one-hot form
y = F.gumbel_softmax(map_tensor.view(B, N, C), tau=tau, hard=True) # [B, N, C]
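# hard=True yields one-hot samples in the forward pass while gradients flow through the soft probabilities (straight-through)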
# Compute integer indices (for the embedding)
node_features = y.argmax(dim=-1).long().unsqueeze(-1) # [B, N, 1]
# Extract the soft wall mask
wall_mask = y[:, :, 1] # [B, N]
adjacency_matrix = torch.zeros(B, N, N, device=map_tensor.device)
def get_index(r, c):
return r * W + c
for r in range(H):
for c in range(W):
idx = get_index(r, c)
if c + 1 < W:
right_idx = get_index(r, c + 1)
adjacency_matrix[:, idx, right_idx] = (1 - wall_mask[:, idx]) * (1 - wall_mask[:, right_idx])
if r + 1 < H:
down_idx = get_index(r + 1, c)
adjacency_matrix[:, idx, down_idx] = (1 - wall_mask[:, idx]) * (1 - wall_mask[:, down_idx])
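# Soft adjacency: an edge's weight is the product of both endpoints' non-wall probabilities; it is thresholded into sparse per-sample edge lists below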
edge_index_list, edge_weight_list = [], []
for b in range(B):
adj_bin = (adjacency_matrix[b] > threshold).float()
edge_index = torch.nonzero(adj_bin, as_tuple=False).T # [2, E]
edge_weight = adjacency_matrix[b][edge_index[0], edge_index[1]]
edge_index_list.append(edge_index)
edge_weight_list.append(edge_weight)
edge_index = torch.cat(edge_index_list, dim=1)
edge_weight = torch.cat(edge_weight_list, dim=0)
return Data(x=node_features, edge_index=edge_index, edge_attr=edge_weight)
class DynamicGraphConverter(nn.Module):
def __init__(self, map_size=13):
super().__init__()
self.map_size = map_size
self.n_nodes = map_size * map_size
# Precompute all candidate edge-index pairs (right and down neighbors)
self.base_edge_index = self._precompute_base_edges()
def _precompute_base_edges(self):
"""预生成全连接边索引(包含所有可能邻接)"""
edge_list = []
directions = [
(0, 1), # right
(1, 0), # down
]
for r in range(self.map_size):
for c in range(self.map_size):
node = r * self.map_size + c
for dr, dc in directions:
nr, nc = r + dr, c + dc
if 0 <= nr < self.map_size and 0 <= nc < self.map_size:
neighbor = nr * self.map_size + nc
edge_list.append([node, neighbor])
return torch.tensor(edge_list).t().contiguous().unique(dim=1)
def forward(self, map_probs, tau=0.5):
B, C, H, W = map_probs.shape
device = map_probs.device
self.base_edge_index = self.base_edge_index.to(device)
# 1. Discretize node features (keeping the operation differentiable)
node_logits = map_probs.view(B, C, -1).permute(0, 2, 1) # [B, N, C]
hard_nodes = F.gumbel_softmax(node_logits, tau=tau, hard=True)
node_ids = hard_nodes.argmax(dim=-1) # [B, N]
# 2. Compute dynamic edge weights
wall_mask = (node_ids == 1).float() # assumes class 1 is wall
edge_weights = self._compute_dynamic_weights(wall_mask)
# 3. Build the dynamic graph
batch_data = []
for b in range(B):
# Dynamically filter out invalid edges (edges touching a wall)
valid_mask = (edge_weights[b] > 0.1).squeeze(-1)
dynamic_edge_index = self.base_edge_index[:, valid_mask]
dynamic_edge_attr = edge_weights[b][valid_mask]
data = Data(
x=node_ids[b],
edge_index=dynamic_edge_index,
edge_attr=dynamic_edge_attr
)
batch_data.append(data)
return Batch.from_data_list(batch_data)
def _compute_dynamic_weights(self, wall_mask):
"""基于墙体存在性计算动态边权重"""
# wall_mask: [B, N]
src_nodes = self.base_edge_index[0] # [E]
dst_nodes = self.base_edge_index[1] # [E]
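# Index the wall mask with the edge endpoints to get per-edge wall flags, broadcast over the batch dimension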
# edge weight = 1 - (source is wall OR target is wall)
weights = 1 - torch.logical_or(
wall_mask[:, src_nodes],
wall_mask[:, dst_nodes]
).float() # [B, E]
return weights.unsqueeze(-1) # [B, E, 1]