"""Evaluate a trained MinamoModel checkpoint on a held-out dataset."""
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from .model.model import MinamoModel
from .model.loss import MinamoLoss
from .dataset import MinamoDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def validate(
    model_path="result/minamo.pth",
    data_path="F:/github-ai/ginka-generator/minamo-eval.json",
    batch_size=32,
):
    """Load a MinamoModel checkpoint and report its mean loss on a dataset.

    Args:
        model_path: Checkpoint file; it must contain a ``"model_state"`` entry
            holding the model's state dict.
        data_path: Path passed to ``MinamoDataset`` for the evaluation data.
        batch_size: Number of samples per evaluation batch.

    Returns:
        The average of the per-batch ``MinamoLoss`` values (also printed).

    Raises:
        ValueError: If the evaluation dataset yields no batches.
    """
    # Using device.type avoids nesting double quotes inside an f-string,
    # which is a SyntaxError before Python 3.12 (PEP 701).
    print(f"Using {device.type} to validate model.")

    # NOTE(review): the meaning of the constructor argument 32 is not visible
    # here — confirm against MinamoModel's signature.
    model = MinamoModel(32)
    # torch.load unpickles the file, so only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint["model_state"])
    model.to(device)
    model.eval()

    # Prepare the evaluation dataset. No shuffling: a fixed order keeps batch
    # composition (and therefore the reported loss) reproducible across runs.
    val_dataset = MinamoDataset(data_path)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    if len(val_loader) == 0:
        raise ValueError(f"Validation dataset at {data_path!r} is empty.")

    criterion = MinamoLoss(temp=0.8)

    val_loss = 0.0
    with torch.no_grad():
        for val_batch in tqdm(val_loader):
            # Each batch is unpacked into four items that are all moved to `device`.
            map1, map2, vision_simi, topo_simi = (t.to(device) for t in val_batch)

            vision_pred, topo_pred = model(map1, map2)
            loss = criterion(vision_pred, topo_pred, vision_simi, topo_simi)
            val_loss += loss.item()

    # Mean of per-batch losses (each batch weighted equally).
    avg_val_loss = val_loss / len(val_loader)
    tqdm.write(f"Validation::loss: {avg_val_loss:.6f}")
    return avg_val_loss


if __name__ == "__main__":
    torch.set_num_threads(2)
    validate()