mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
56 lines
1.9 KiB
Python
56 lines
1.9 KiB
Python
import torch
|
|
import torch.nn.functional as F
|
|
from torch_geometric.loader import DataLoader
|
|
from tqdm import tqdm
|
|
from .model.model import MinamoModel
|
|
from .model.loss import MinamoLoss
|
|
from .dataset import MinamoDataset
|
|
|
|
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
def validate(
    checkpoint_path="result/minamo.pth",
    dataset_path="minamo-eval.json",
    batch_size=32,
):
    """Evaluate a trained Minamo checkpoint on an evaluation dataset.

    Builds ``MinamoModel(32)`` and loads the weights stored under the
    ``"model_state"`` key of ``checkpoint_path``. It then runs every batch of
    ``MinamoDataset(dataset_path)`` through the model with gradients
    disabled. The mean of the per-batch ``MinamoLoss`` values is written to
    the console and returned.

    Args:
        checkpoint_path: Checkpoint file containing a ``"model_state"`` entry.
        dataset_path: Path handed to ``MinamoDataset`` for the evaluation data.
        batch_size: Number of samples per evaluation batch.

    Returns:
        The average per-batch validation loss, as a Python float.

    Raises:
        ValueError: If the evaluation dataset produces no batches.
    """
    print(f"Using {device} to validate model.")

    # Prepare the dataset first so an empty evaluation set fails fast,
    # before the checkpoint is loaded.
    val_dataset = MinamoDataset(dataset_path)
    # No shuffling: order does not matter for evaluation, and a shuffled
    # (possibly smaller) final batch would make the mean of batch means
    # differ from run to run.
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
    )
    num_batches = len(val_loader)
    if num_batches == 0:
        raise ValueError(f"Evaluation dataset {dataset_path!r} produced no batches.")

    model = MinamoModel(32)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state"])
    model.to(device)
    model.eval()

    criterion = MinamoLoss()

    total_loss = 0.0
    with torch.no_grad():
        for val_batch in tqdm(val_loader):
            map1_val, map2_val, vision_simi_val, topo_simi_val, graph1, graph2 = val_batch
            map1_val = map1_val.to(device)
            map2_val = map2_val.to(device)
            vision_simi_val = vision_simi_val.to(device)
            topo_simi_val = topo_simi_val.to(device)
            graph1 = graph1.to(device)
            graph2 = graph2.to(device)

            # The model returns a (vision, topology) feature pair for each map/graph input.
            vision_feat1, topo_feat1 = model(map1_val, graph1)
            vision_feat2, topo_feat2 = model(map2_val, graph2)

            # Cosine similarity over the last axis. The trailing unsqueeze
            # presumably aligns the prediction shape with the similarity
            # targets -- TODO confirm against MinamoDataset.
            vision_pred_val = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
            topo_pred_val = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)

            loss_val = criterion(
                vision_pred_val, topo_pred_val,
                vision_simi_val, topo_simi_val
            )
            total_loss += loss_val.item()

    avg_val_loss = total_loss / num_batches
    tqdm.write(f"Validation::loss: {avg_val_loss:.6f}")
    return avg_val_loss
|
|
|
|
if __name__ == "__main__":
    # Cap the number of CPU threads PyTorch may use, then run the evaluation.
    torch.set_num_threads(2)
    validate()
|
|
|