refactor: 余弦相似度部分单独拉出来计算

This commit is contained in:
unanmed 2025-03-18 18:22:19 +08:00
parent 1ebd03390d
commit 09b66f2569
2 changed files with 13 additions and 10 deletions

View File

@ -18,7 +18,4 @@ class MinamoModel(nn.Module):
topo_feat1 = self.topo_model(graph1)
topo_feat2 = self.topo_model(graph2)
vision_sim = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
topo_sim = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
return vision_sim, topo_sim
return vision_feat1, vision_feat2, topo_feat1, topo_feat2

View File

@ -2,6 +2,7 @@ import os
from datetime import datetime
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from tqdm import tqdm
from .model.model import MinamoModel
@ -75,7 +76,10 @@ def train():
# 前向传播
optimizer.zero_grad()
vision_pred, topo_pred = model(map1, map2, graph1, graph2)
vision_feat1, vision_feat2, topo_feat1, topo_feat2 = model(map1, map2, graph1, graph2)
vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
# 计算损失
loss = criterion(vision_pred, topo_pred, vision_simi, topo_simi)
@ -117,11 +121,13 @@ def train():
graph1 = graph1.to(device)
graph2 = graph2.to(device)
vision_pred_val, topo_pred_val = model(map1_val, map2_val, graph1, graph2)
loss_val = criterion(
vision_pred_val, topo_pred_val,
vision_simi_val, topo_simi_val
)
vision_feat1, vision_feat2, topo_feat1, topo_feat2 = model(map1_val, map2_val, graph1, graph2)
vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
# 计算损失
loss_val = criterion(vision_pred, topo_pred, vision_simi, topo_simi)
val_loss += loss_val.item()
avg_val_loss = val_loss / len(val_loader)