# Mirror of https://github.com/unanmed/ginka-generator.git
# Synced 2026-05-15 05:11:10 +08:00
# 49 lines, 1.6 KiB, Python
import torch
|
|
from torch.utils.data import DataLoader
|
|
from tqdm import tqdm
|
|
from .model.model import MinamoModel
|
|
from .model.loss import MinamoLoss
|
|
from .dataset import MinamoDataset
|
|
|
|
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
def validate(
    model_path: str = "result/minamo.pth",
    data_path: str = "F:/github-ai/ginka-generator/minamo-eval.json",
    batch_size: int = 32,
) -> None:
    """Evaluate a saved Minamo checkpoint and print its mean per-batch loss.

    Builds ``MinamoModel(32)``, restores its weights from the
    ``"model_state"`` entry of the checkpoint, runs it over the evaluation
    dataset with gradient tracking disabled, and reports the loss averaged
    over batches through ``tqdm.write``.

    Args:
        model_path: Checkpoint file handed to ``torch.load``.
        data_path: Evaluation data file handed to ``MinamoDataset``.
        batch_size: Number of samples per evaluation batch.

    Raises:
        ValueError: If the data loader produces no batches, so no average
            loss can be computed.
    """
    # `device.type` is the same "cuda"/"cpu" string the message always showed;
    # using it avoids nesting double quotes inside the f-string, which only
    # parses on Python 3.12+.
    print(f"Using {device.type} to validate model.")

    model = MinamoModel(32)
    # NOTE(review): torch.load unpickles the file, so only trusted checkpoints
    # should be passed here -- consider `weights_only=True` if the checkpoint
    # layout allows it.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint["model_state"])
    model.to(device)

    # Prepare the dataset. Shuffling is unnecessary for evaluation and would
    # make the batch composition differ from run to run.
    val_dataset = MinamoDataset(data_path)
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
    )

    num_batches = len(val_loader)
    if num_batches == 0:
        # Fail with a clear message instead of a ZeroDivisionError below.
        raise ValueError(f"Validation dataset is empty: {data_path}")

    criterion = MinamoLoss(temp=0.8)

    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for batch in tqdm(val_loader):
            # Each batch is unpacked as (map1, map2, vision similarity,
            # topology similarity); move every tensor to the target device.
            map1, map2, vision_simi, topo_simi = (
                tensor.to(device) for tensor in batch
            )

            vision_pred, topo_pred = model(map1, map2)
            batch_loss = criterion(
                vision_pred, topo_pred,
                vision_simi, topo_simi,
            )
            total_loss += batch_loss.item()

    avg_val_loss = total_loss / num_batches
    tqdm.write(f"Validation::loss: {avg_val_loss:.6f}")
|
|
|
|
if __name__ == "__main__":
    # Script entry point: cap PyTorch's CPU thread count at two, then run
    # the validation pass.
    torch.set_num_threads(2)
    validate()
|
|
|