mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
refactor: 余弦相似度部分单独拉出来计算
This commit is contained in:
parent
1ebd03390d
commit
09b66f2569
@ -18,7 +18,4 @@ class MinamoModel(nn.Module):
|
||||
topo_feat1 = self.topo_model(graph1)
|
||||
topo_feat2 = self.topo_model(graph2)
|
||||
|
||||
vision_sim = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
|
||||
topo_sim = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
|
||||
|
||||
return vision_sim, topo_sim
|
||||
return vision_feat1, vision_feat2, topo_feat1, topo_feat2
|
||||
|
||||
@ -2,6 +2,7 @@ import os
|
||||
from datetime import datetime
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
import torch.nn.functional as F
|
||||
from torch_geometric.loader import DataLoader
|
||||
from tqdm import tqdm
|
||||
from .model.model import MinamoModel
|
||||
@ -75,7 +76,10 @@ def train():
|
||||
|
||||
# 前向传播
|
||||
optimizer.zero_grad()
|
||||
vision_pred, topo_pred = model(map1, map2, graph1, graph2)
|
||||
vision_feat1, vision_feat2, topo_feat1, topo_feat2 = model(map1, map2, graph1, graph2)
|
||||
|
||||
vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
|
||||
topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
|
||||
|
||||
# 计算损失
|
||||
loss = criterion(vision_pred, topo_pred, vision_simi, topo_simi)
|
||||
@ -117,11 +121,13 @@ def train():
|
||||
graph1 = graph1.to(device)
|
||||
graph2 = graph2.to(device)
|
||||
|
||||
vision_pred_val, topo_pred_val = model(map1_val, map2_val, graph1, graph2)
|
||||
loss_val = criterion(
|
||||
vision_pred_val, topo_pred_val,
|
||||
vision_simi_val, topo_simi_val
|
||||
)
|
||||
vision_feat1, vision_feat2, topo_feat1, topo_feat2 = model(map1_val, map2_val, graph1, graph2)
|
||||
|
||||
vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
|
||||
topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
|
||||
|
||||
# 计算损失
|
||||
loss_val = criterion(vision_pred, topo_pred, vision_simi, topo_simi)
|
||||
val_loss += loss_val.item()
|
||||
|
||||
avg_val_loss = val_loss / len(val_loader)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user