diff --git a/minamo/model/model.py b/minamo/model/model.py index 25c322f..1f09a65 100644 --- a/minamo/model/model.py +++ b/minamo/model/model.py @@ -18,7 +18,4 @@ class MinamoModel(nn.Module): topo_feat1 = self.topo_model(graph1) topo_feat2 = self.topo_model(graph2) - vision_sim = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1) - topo_sim = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1) - - return vision_sim, topo_sim + return vision_feat1, vision_feat2, topo_feat1, topo_feat2 diff --git a/minamo/train.py b/minamo/train.py index cbe481f..04a0103 100644 --- a/minamo/train.py +++ b/minamo/train.py @@ -2,6 +2,7 @@ import os from datetime import datetime import torch import torch.optim as optim +import torch.nn.functional as F from torch_geometric.loader import DataLoader from tqdm import tqdm from .model.model import MinamoModel @@ -75,7 +76,10 @@ def train(): # 前向传播 optimizer.zero_grad() - vision_pred, topo_pred = model(map1, map2, graph1, graph2) + vision_feat1, vision_feat2, topo_feat1, topo_feat2 = model(map1, map2, graph1, graph2) + + vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1) + topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1) # 计算损失 loss = criterion(vision_pred, topo_pred, vision_simi, topo_simi) @@ -117,11 +121,13 @@ def train(): graph1 = graph1.to(device) graph2 = graph2.to(device) - vision_pred_val, topo_pred_val = model(map1_val, map2_val, graph1, graph2) - loss_val = criterion( - vision_pred_val, topo_pred_val, - vision_simi_val, topo_simi_val - ) + vision_feat1, vision_feat2, topo_feat1, topo_feat2 = model(map1_val, map2_val, graph1, graph2) + + vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1) + topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1) + + # 计算损失 + loss_val = criterion(vision_pred, topo_pred, vision_simi, topo_simi) val_loss += loss_val.item() avg_val_loss = val_loss / len(val_loader)