ginka-generator/minamo/validate.py

49 lines
1.6 KiB
Python

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from .model.model import MinamoModel
from .model.loss import MinamoLoss
from .dataset import MinamoDataset
# Module-wide compute device: the default CUDA device when PyTorch can see
# one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def validate(
    model_path="result/minamo.pth",
    data_path="F:/github-ai/ginka-generator/minamo-eval.json",
    batch_size=32,
):
    """Evaluate a trained MinamoModel checkpoint on a validation dataset.

    Loads the state dict stored under ``"model_state"`` in ``model_path``,
    runs the model over every batch of ``data_path`` with gradients disabled,
    prints the mean of the per-batch ``MinamoLoss`` values and returns it.

    Args:
        model_path: Checkpoint file; must deserialize to a mapping that has a
            ``"model_state"`` entry accepted by ``MinamoModel.load_state_dict``.
        data_path: File passed to ``MinamoDataset``.
        batch_size: Number of samples per validation batch.

    Returns:
        The average of the per-batch validation losses, as a float.

    Raises:
        ValueError: If the validation dataset yields no batches.
    """
    # `device` already records the chosen backend ("cuda" or "cpu"). Using it
    # avoids re-querying CUDA and the nested same-quote f-string, which is a
    # SyntaxError before Python 3.12.
    print(f"Using {device.type} to validate model.")

    # NOTE(review): the constructor argument presumably has to match the
    # configuration used at training time; confirm against the training script.
    model = MinamoModel(32)
    # torch.load unpickles the file; only point this at trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint["model_state"])
    model.to(device)

    # Prepare the dataset.
    val_dataset = MinamoDataset(data_path)
    # Order is irrelevant for evaluation. A fixed order keeps the reported
    # loss reproducible, because the last (possibly smaller) batch always
    # contains the same samples.
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    num_batches = len(val_loader)
    if num_batches == 0:
        raise ValueError(f"Validation dataset {data_path!r} produced no batches.")

    criterion = MinamoLoss(temp=0.8)
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for map1_val, map2_val, vision_simi_val, topo_simi_val in tqdm(val_loader):
            map1_val = map1_val.to(device)
            map2_val = map2_val.to(device)
            vision_simi_val = vision_simi_val.to(device)
            topo_simi_val = topo_simi_val.to(device)

            vision_pred_val, topo_pred_val = model(map1_val, map2_val)
            loss_val = criterion(
                vision_pred_val, topo_pred_val,
                vision_simi_val, topo_simi_val
            )
            val_loss += loss_val.item()

    avg_val_loss = val_loss / num_batches
    tqdm.write(f"Validation::loss: {avg_val_loss:.6f}")
    return avg_val_loss
def _main():
    """Script entry point: cap PyTorch's CPU threads, then run validation."""
    # Limit intra-op parallelism to two CPU threads for this run.
    torch.set_num_threads(2)
    validate()


if __name__ == "__main__":
    _main()