From 09c63fedce8f13a3dd6e36986560830d1f00ac98 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Tue, 18 Mar 2025 23:51:14 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20GINKA=20=E7=94=9F=E6=88=90=E5=99=A8?= =?UTF-8?q?=E4=BD=BF=E7=94=A8=20Minamo=20=E4=BD=9C=E4=B8=BA=E6=8D=9F?= =?UTF-8?q?=E5=A4=B1=E5=80=BC=E7=9A=84=E4=B8=80=E9=83=A8=E5=88=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/dataset.py | 13 ++--- ginka/model/loss.py | 40 ++++++++------ ginka/model/model.py | 4 +- ginka/model/unet.py | 46 +++++++--------- ginka/train.py | 53 +++++++++++++++--- shared/graph.py | 126 ++++++++++++++++++++++++++++++++++++++++++- 6 files changed, 221 insertions(+), 61 deletions(-) diff --git a/ginka/dataset.py b/ginka/dataset.py index d1158eb..1f92150 100644 --- a/ginka/dataset.py +++ b/ginka/dataset.py @@ -15,10 +15,11 @@ def load_data(path: str): return data_list class GinkaDataset(Dataset): - def __init__(self, data_path: str, minamo: MinamoModel): + def __init__(self, data_path: str, device, minamo: MinamoModel): self.data = load_data(data_path) # 自定义数据加载函数 self.max_size = 32 self.minamo = minamo + self.device = device def __len__(self): return len(self.data) @@ -26,13 +27,13 @@ class GinkaDataset(Dataset): def __getitem__(self, idx): item = self.data[idx] - target = torch.tensor(item["map"]) - graph = convert_map_to_graph(target) - vision_feat, topo_feat = self.minamo(target, graph) - feat_vec = torch.cat([vision_feat, topo_feat]) + target = torch.tensor(item["map"]).to(self.device) + graph = convert_map_to_graph(target).to(self.device) + vision_feat, topo_feat = self.minamo(target.unsqueeze(0), graph) return { - "feat_vec": feat_vec, + "target_vision_feat": vision_feat, + "target_topo_feat": topo_feat, "target": target } \ No newline at end of file diff --git a/ginka/model/loss.py b/ginka/model/loss.py index 304d621..ea1e495 100644 --- a/ginka/model/loss.py +++ b/ginka/model/loss.py @@ -3,6 +3,7 @@ import torch import torch.nn as nn import torch.nn.functional as F from minamo.model.model import MinamoModel +from shared.graph import DynamicGraphConverter def wall_border_loss(pred: torch.Tensor, probs: torch.Tensor, allow_border=[1, 11]): """地图最外层是否为墙""" @@ -131,7 +132,6 @@ def entrance_distance_and_presence_loss( total_loss: 综合入口距离与存在性损失 """ # 将 logits 转换为概率分布 - probs = F.softmax(logits, dim=1) # [B, C, H, W] B, C, H, W = logits.shape # 提取箭头和楼梯的概率图 @@ -147,9 +147,9 @@ def entrance_distance_and_presence_loss( arrow_distance_loss = F.relu(arrow_excess).mean() # 楼梯:使用窗口大小为 (W//2, H//2) - kernel_size_stairs = (max(1, W // 2), max(1, H // 2)) + kernel_size_stairs = (9, 9) kernel_stairs = torch.ones((1, 1, kernel_size_stairs[0], kernel_size_stairs[1]), device=logits.device) - pad_stairs = (kernel_size_stairs[0] // 2, kernel_size_stairs[1] // 2) + pad_stairs = ((kernel_size_stairs[0] - 1) // 2, (kernel_size_stairs[1] - 1) // 2) local_stairs_sum = F.conv2d(stairs_probs.unsqueeze(1), kernel_stairs, padding=pad_stairs) stairs_excess = local_stairs_sum - stairs_probs.unsqueeze(1) stairs_distance_loss = F.relu(stairs_excess).mean() @@ -283,7 +283,7 @@ def integrated_count_loss(probs, target, class_list=[0,1,2,3,4,5,6,7,8,9], toler return avg_loss class GinkaLoss(nn.Module): - def __init__(self, minamo: MinamoModel, weight=[0.35, 0.1, 0.1, 0.1, 0.1, 0.05, 0.1, 0.1]): + def __init__(self, minamo: MinamoModel, converter: DynamicGraphConverter, weight=[0.35, 0.1, 0.1, 0.1, 0.1, 0.05, 0.1, 0.1]): """Ginka Model 损失函数部分 Args: @@ -301,8 +301,10 @@ class GinkaLoss(nn.Module): self.weight = weight self.ce = nn.CrossEntropyLoss() self.minamo = minamo + self.tau = 1 + self.converter = converter - def forward(self, pred, pred_softmax, target): + def forward(self, pred, pred_softmax, target, target_vision_feat, target_topo_feat): probs = F.softmax(pred, dim=1) # 地图结构损失 border_loss = wall_border_loss(pred, probs) @@ -314,21 +316,27 @@ class GinkaLoss(nn.Module): count_loss = integrated_count_loss(probs, target) # 使用 Minamo Model 计算相似度 + graph = self.converter(pred, tau=self.tau) + pred_vision_feat, pred_topo_feat = self.minamo(pred_softmax, graph) + vision_sim = F.cosine_similarity(pred_vision_feat, target_vision_feat, dim=-1) + topo_sim = F.cosine_similarity(pred_topo_feat, target_topo_feat, dim=-1) + minamo_sim = 0.3 * vision_sim + 0.7 * topo_sim + minamo_loss = torch.exp(-10 * (minamo_sim - 0.8)).mean() - print( - # structure_loss.item(), - border_loss.item(), - wall_loss.item(), - entry_loss.item(), - entry_dis_loss.item(), - enemy_loss.item(), - valid_block_loss.item(), - count_loss.item() - ) + # print( + # minamo_loss.item(), + # border_loss.item(), + # wall_loss.item(), + # entry_loss.item(), + # entry_dis_loss.item(), + # enemy_loss.item(), + # valid_block_loss.item(), + # count_loss.item() + # ) return ( - # structure_loss * self.weight[0] + + minamo_loss * self.weight[0] + border_loss * self.weight[1] + wall_loss * self.weight[2] + entry_loss * self.weight[3] + diff --git a/ginka/model/model.py b/ginka/model/model.py index f79ce49..113ea22 100644 --- a/ginka/model/model.py +++ b/ginka/model/model.py @@ -14,8 +14,8 @@ class GumbelSoftmax(nn.Module): y = F.gumbel_softmax(logits, tau=self.tau, hard=self.hard) # 转换为类索引的连续表示 - class_indices = torch.arange(y.size(1), device=y.device).view(1, -1, 1, 1) - return (y * class_indices).sum(dim=1) # 形状[BS, H, W] + # class_indices = torch.arange(y.size(1), device=y.device).view(1, -1, 1, 1) + return y.argmax(dim=1) # 形状[BS, H, W] class GinkaModel(nn.Module): def __init__(self, feat_dim=256, base_ch=64, num_classes=32): diff --git a/ginka/model/unet.py b/ginka/model/unet.py index 58fee2f..ec0f2d3 100644 --- a/ginka/model/unet.py +++ b/ginka/model/unet.py @@ -16,33 +16,25 @@ class GinkaEncoder(nn.Module): self.pool = nn.MaxPool2d(2) def forward(self, x): - x_res = self.conv(x) # 卷积提取特征 - x_down = self.pool(x_res) # 进行池化 - return x_down, x_res # 返回池化后的特征和跳跃连接特征 + x_res = self.conv(x) + x_down = self.pool(x_res) + return x_down, x_res class GinkaDecoder(nn.Module): """解码器(上采样)部分""" def __init__(self, in_channels, out_channels): super().__init__() - # 上采样(双线性插值 + 卷积) - self.upsample = nn.Sequential( - nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True), - nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), - nn.BatchNorm2d(out_channels), - nn.ReLU() - ) - - # 跳跃连接融合 + self.upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2) + self.conv = nn.Sequential( - nn.Conv2d(out_channels * 2, out_channels, kernel_size=3, padding=1), + nn.Conv2d(in_channels + out_channels, out_channels, kernel_size=3, padding=1), nn.BatchNorm2d(out_channels), nn.ReLU() ) def forward(self, x, skip): x = self.upsample(x) - # 跳跃连接融合 - x = torch.cat([x, skip], dim=1) + x = torch.cat([x, skip], dim=1) x = self.conv(x) return x @@ -66,26 +58,26 @@ class GinkaUNet(nn.Module): """Ginka Model UNet 部分 """ super().__init__() - self.down1 = GinkaEncoder(in_ch, in_ch*2) self.down2 = GinkaEncoder(in_ch*2, in_ch*4) - + self.bottleneck = GinkaBottleneck(in_ch*4, in_ch*4) - + self.up1 = GinkaDecoder(in_ch*4, in_ch*2) self.up2 = GinkaDecoder(in_ch*2, in_ch) - + self.final = nn.Sequential( - nn.Conv2d(in_ch, out_ch, 1) + nn.Conv2d(in_ch, out_ch, 1), + # nn.Softmax(dim=1) # 适用于分类任务 ) def forward(self, x): - x, skip1 = self.down1(x) - x, skip2 = self.down2(x) - - x = self.bottleneck(x) - - x = self.up1(x, skip2) - x = self.up2(x, skip1) + x_down1, skip1 = self.down1(x) + x_down2, skip2 = self.down2(x_down1) + + x = self.bottleneck(x_down2) + + x = self.up1(x, skip2) # 用 down2 的 skip + x = self.up2(x, skip1) # 用 down1 的 skip return self.final(x) diff --git a/ginka/train.py b/ginka/train.py index 076de3b..eb3559d 100644 --- a/ginka/train.py +++ b/ginka/train.py @@ -2,17 +2,17 @@ import os from datetime import datetime import torch import torch.optim as optim -import torch.nn.functional as F from torch.utils.data import DataLoader -from transformers import BertTokenizer from tqdm import tqdm from .model.model import GinkaModel from .model.loss import GinkaLoss from .dataset import GinkaDataset from minamo.model.model import MinamoModel +from shared.graph import DynamicGraphConverter device = torch.device("cuda" if torch.cuda.is_available() else "cpu") os.makedirs("result", exist_ok=True) +os.makedirs("result/ginka_checkpoint", exist_ok=True) epochs = 70 @@ -29,48 +29,85 @@ def train(): minamo = MinamoModel(32) minamo.to(device) minamo.eval() + + converter = DynamicGraphConverter().to(device) # 准备数据集 - dataset = GinkaDataset("dataset.json", minamo) + dataset = GinkaDataset("ginka-dataset.json", device, minamo) + dataset_val = GinkaDataset("ginka-eval.json", device, minamo) dataloader = DataLoader( dataset, batch_size=32, shuffle=True ) + dataloader_val = DataLoader( + dataset_val, + batch_size=32, + shuffle=True + ) # 设定优化器与调度器 optimizer = optim.AdamW(model.parameters(), lr=3e-4) scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6) - criterion = GinkaLoss(minamo) + criterion = GinkaLoss(minamo, converter) # 开始训练 for epoch in tqdm(range(epochs)): model.train() total_loss = 0 model.softmax.tau = update_tau(epoch) + criterion.tau = update_tau(epoch) for batch in dataloader: # 数据迁移到设备 target = batch["target"].to(device) - feat_vec = batch["feat_vec"].to(device) - + target_vision_feat = batch["target_vision_feat"].to(device) + target_topo_feat = batch["target_topo_feat"].to(device) + feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device) # 前向传播 optimizer.zero_grad() output, output_softmax = model(feat_vec) # 计算损失 - loss = criterion(output, output_softmax, target) + loss = criterion(output, output_softmax, target, target_vision_feat, target_topo_feat) # 反向传播 loss.backward() optimizer.step() total_loss += loss.item() - tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch} | loss: {total_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}") + avg_loss = total_loss / len(dataloader) + tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch} | loss: {avg_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}") # 学习率调整 scheduler.step() + if (epoch + 1) % 5 == 0: + loss_val = 0 + model.eval() + with torch.no_grad(): + for batch in dataloader_val: + # 数据迁移到设备 + target = batch["target"].to(device) + target_vision_feat = batch["target_vision_feat"].to(device) + target_topo_feat = batch["target_topo_feat"].to(device) + feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device) + + # 前向传播 + output, output_softmax = model(feat_vec) + + # 计算损失 + loss = criterion(output, output_softmax, target, target_vision_feat, target_topo_feat) + loss_val += loss.item() + + avg_val_loss = loss_val / len(dataloader_val) + tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}") + torch.save({ + "model_state": model.state_dict(), + "optimizer_state": optimizer.state_dict(), + }, f"result/ginka_checkpoint/{epoch + 1}.pth") + + print("Train ended.") torch.save({ diff --git a/shared/graph.py b/shared/graph.py index acbed92..6dc9e3e 100644 --- a/shared/graph.py +++ b/shared/graph.py @@ -1,5 +1,8 @@ import torch -from torch_geometric.data import Data +import torch.nn as nn +import torch.nn.functional as F +from torch_geometric.data import Data, Batch +from torch_geometric.utils import dense_to_sparse def convert_map_to_graph(map): rows = len(map) @@ -26,4 +29,123 @@ def convert_map_to_graph(map): edge_index = torch.tensor(edge_list, dtype=torch.long).T node_features = torch.tensor([node_type for (_, _, node_type) in valid_nodes], dtype=torch.long) - return Data(x=node_features, edge_index=edge_index) \ No newline at end of file + return Data(x=node_features, edge_index=edge_index) + +def soft_convert_map_to_graph(map_tensor, tau=0.5, threshold=0.1): + """ + 将地图批量转换为 GNN 可用的 Graph Data,使用 soft 策略 + """ + B, H, W, C = map_tensor.shape + N = H * W + + # 使用 Gumbel-Softmax 确保是 one-hot 形式 + y = F.gumbel_softmax(map_tensor.view(B, N, C), tau=tau, hard=True) # [B, N, C] + + # 计算整数索引(用于 embedding) + node_features = y.argmax(dim=-1).long().unsqueeze(-1) # [B, N, 1] + + # 取出墙体的 soft mask + wall_mask = y[:, :, 1] # [B, N] + + adjacency_matrix = torch.zeros(B, N, N, device=map_tensor.device) + + def get_index(r, c): + return r * W + c + + for r in range(H): + for c in range(W): + idx = get_index(r, c) + if c + 1 < W: + right_idx = get_index(r, c + 1) + adjacency_matrix[:, idx, right_idx] = (1 - wall_mask[:, idx]) * (1 - wall_mask[:, right_idx]) + if r + 1 < H: + down_idx = get_index(r + 1, c) + adjacency_matrix[:, idx, down_idx] = (1 - wall_mask[:, idx]) * (1 - wall_mask[:, down_idx]) + + edge_index_list, edge_weight_list = [], [] + for b in range(B): + adj_bin = (adjacency_matrix[b] > threshold).float() + edge_index = torch.nonzero(adj_bin, as_tuple=False).T # [2, E] + edge_weight = adjacency_matrix[b][edge_index[0], edge_index[1]] + + edge_index_list.append(edge_index) + edge_weight_list.append(edge_weight) + + edge_index = torch.cat(edge_index_list, dim=1) + edge_weight = torch.cat(edge_weight_list, dim=0) + + return Data(x=node_features, edge_index=edge_index, edge_attr=edge_weight) + +class DynamicGraphConverter(nn.Module): + def __init__(self, map_size=13): + super().__init__() + self.map_size = map_size + self.n_nodes = map_size * map_size + + # 预计算所有可能的边索引组合(包括对角线) + self.base_edge_index = self._precompute_base_edges() + + def _precompute_base_edges(self): + """预生成全连接边索引(包含所有可能邻接)""" + edge_list = [] + directions = [ + (0, 1), # 右 + (1, 0), # 下 + ] + + for r in range(self.map_size): + for c in range(self.map_size): + node = r * self.map_size + c + for dr, dc in directions: + nr, nc = r + dr, c + dc + if 0 <= nr < self.map_size and 0 <= nc < self.map_size: + neighbor = nr * self.map_size + nc + edge_list.append([node, neighbor]) + + return torch.tensor(edge_list).t().contiguous().unique(dim=1) + + def forward(self, map_probs, tau=0.5): + B, C, H, W = map_probs.shape + device = map_probs.device + + self.base_edge_index = self.base_edge_index.to(device) + + # 1. 节点特征离散化(保持可导) + node_logits = map_probs.view(B, C, -1).permute(0, 2, 1) # [B, N, C] + hard_nodes = F.gumbel_softmax(node_logits, tau=tau, hard=True) + node_ids = hard_nodes.argmax(dim=-1) # [B, N] + + # 2. 动态边权重计算 + wall_mask = (node_ids == 1).float() # 假设类别1是墙体 + edge_weights = self._compute_dynamic_weights(wall_mask) + + # 3. 构建动态图 + batch_data = [] + for b in range(B): + # 动态过滤无效边(与墙体相连的边) + valid_mask = (edge_weights[b] > 0.1).squeeze(-1) + dynamic_edge_index = self.base_edge_index[:, valid_mask] + dynamic_edge_attr = edge_weights[b][valid_mask] + + data = Data( + x=node_ids[b], + edge_index=dynamic_edge_index, + edge_attr=dynamic_edge_attr + ) + batch_data.append(data) + + return Batch.from_data_list(batch_data) + + def _compute_dynamic_weights(self, wall_mask): + """基于墙体存在性计算动态边权重""" + # wall_mask: [B, N] + src_nodes = self.base_edge_index[0] # [E] + dst_nodes = self.base_edge_index[1] # [E] + + # 边权重 = 1 - (源是墙 OR 目标墙) + weights = 1 - torch.logical_or( + wall_mask[:, src_nodes], + wall_mask[:, dst_nodes] + ).float() # [B, E] + + return weights.unsqueeze(-1) # [B, E, 1]