From 09b66f25690bdd376bc87b9931bff6fad516c6f2 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Tue, 18 Mar 2025 18:22:19 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E4=BD=99=E5=BC=A6=E7=9B=B8?= =?UTF-8?q?=E4=BC=BC=E5=BA=A6=E9=83=A8=E5=88=86=E5=8D=95=E7=8B=AC=E6=8B=89?= =?UTF-8?q?=E5=87=BA=E6=9D=A5=E8=AE=A1=E7=AE=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- minamo/model/model.py | 5 +---- minamo/train.py | 18 ++++++++++++------ 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/minamo/model/model.py b/minamo/model/model.py index 25c322f..1f09a65 100644 --- a/minamo/model/model.py +++ b/minamo/model/model.py @@ -18,7 +18,4 @@ class MinamoModel(nn.Module): topo_feat1 = self.topo_model(graph1) topo_feat2 = self.topo_model(graph2) - vision_sim = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1) - topo_sim = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1) - - return vision_sim, topo_sim + return vision_feat1, vision_feat2, topo_feat1, topo_feat2 diff --git a/minamo/train.py b/minamo/train.py index cbe481f..04a0103 100644 --- a/minamo/train.py +++ b/minamo/train.py @@ -2,6 +2,7 @@ import os from datetime import datetime import torch import torch.optim as optim +import torch.nn.functional as F from torch_geometric.loader import DataLoader from tqdm import tqdm from .model.model import MinamoModel @@ -75,7 +76,10 @@ def train(): # 前向传播 optimizer.zero_grad() - vision_pred, topo_pred = model(map1, map2, graph1, graph2) + vision_feat1, vision_feat2, topo_feat1, topo_feat2 = model(map1, map2, graph1, graph2) + + vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1) + topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1) # 计算损失 loss = criterion(vision_pred, topo_pred, vision_simi, topo_simi) @@ -117,11 +121,13 @@ def train(): graph1 = graph1.to(device) graph2 = graph2.to(device) - vision_pred_val, topo_pred_val = model(map1_val, map2_val, graph1, graph2) - loss_val = criterion( - vision_pred_val, topo_pred_val, - vision_simi_val, topo_simi_val - ) + vision_feat1, vision_feat2, topo_feat1, topo_feat2 = model(map1_val, map2_val, graph1, graph2) + + vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1) + topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1) + + # 计算损失 + loss_val = criterion(vision_pred, topo_pred, vision_simi, topo_simi) val_loss += loss_val.item() avg_val_loss = val_loss / len(val_loader)