ginka-generator/ginka/validate.py

55 lines
1.8 KiB
Python

import torch
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from tqdm import tqdm
from minamo.model.model import MinamoModel
from .dataset import GinkaDataset
from .model.loss import GinkaLoss
from .model.model import GinkaModel
from shared.graph import DynamicGraphConverter
# Run on CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def validate(model_path=None):
    """Evaluate a GinkaModel on ``ginka-eval.json`` and report the mean loss.

    The loss is computed by ``GinkaLoss``, built from a Minamo model whose
    weights are loaded from ``result/minamo.pth`` and a
    ``DynamicGraphConverter``.

    Args:
        model_path: Optional path to a saved Ginka checkpoint. When given, the
            file is loaded with ``torch.load`` and its ``"model_state"`` entry
            (or the loaded object itself, if it has no such key) is passed to
            ``model.load_state_dict``. When ``None`` (the default, matching
            the previous behaviour), the freshly constructed model is
            evaluated as-is.

    Returns:
        The average per-batch validation loss as a float (``0.0`` when the
        validation loader yields no batches).
    """
    print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.")

    model = GinkaModel()
    if model_path is not None:
        checkpoint = torch.load(model_path, map_location=device)
        # Accept both {"model_state": ...} wrappers (the format minamo.pth is
        # read with below) and bare state dicts.
        if isinstance(checkpoint, dict) and "model_state" in checkpoint:
            checkpoint = checkpoint["model_state"]
        model.load_state_dict(checkpoint)
    # The model must live on the same device as the input tensors, and eval()
    # switches layers such as dropout / batch-norm to inference behaviour.
    model.to(device)
    model.eval()

    minamo = MinamoModel(32)
    minamo.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"])
    minamo.to(device)
    minamo.eval()

    # Prepare the dataset. Shuffling is pointless for evaluation and would make
    # the per-batch grouping (and therefore the reported mean) nondeterministic.
    val_dataset = GinkaDataset("ginka-eval.json")
    val_loader = DataLoader(
        val_dataset,
        batch_size=32,
        shuffle=False,
    )

    converter = DynamicGraphConverter().to(device)
    criterion = GinkaLoss(minamo, converter)

    val_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validating", leave=False):
            # Move the batch to the target device.
            target = batch["target"].to(device)
            target_vision_feat = batch["target_vision_feat"].to(device)
            target_topo_feat = batch["target_topo_feat"].to(device)
            # Both parts are already on ``device``, so the concatenation is too.
            feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1)

            # Forward pass; the model's second output is not used here.
            output, _ = model(feat_vec)
            # NOTE(review): dim=1 is presumably the per-tile class axis of
            # ``output`` — confirm against GinkaModel's output layout.
            map_matrix = torch.argmax(output, dim=1)

            # Compute the loss.
            loss = criterion(output, map_matrix, target, target_vision_feat, target_topo_feat)
            val_loss += loss.item()
            num_batches += 1

    # Guard against an empty validation set instead of dividing by zero.
    avg_val_loss = val_loss / num_batches if num_batches else 0.0
    tqdm.write(f"Validation::loss: {avg_val_loss:.6f}")
    return avg_val_loss


def _main():
    """Script entry point: limit CPU threading, then run validation."""
    # Cap PyTorch's intra-op CPU thread pool at two threads.
    torch.set_num_threads(2)
    validate()


if __name__ == "__main__":
    _main()