mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-16 14:31:11 +08:00
55 lines
1.8 KiB
Python
55 lines
1.8 KiB
Python
import torch
|
|
import torch.nn.functional as F
|
|
from torch_geometric.loader import DataLoader
|
|
from tqdm import tqdm
|
|
from minamo.model.model import MinamoModel
|
|
from .dataset import GinkaDataset
|
|
from .model.loss import GinkaLoss
|
|
from .model.model import GinkaModel
|
|
from shared.graph import DynamicGraphConverter
|
|
|
|
# Compute device shared by the whole script: the GPU when CUDA is usable,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
|
|
|
|
def validate():
    """Evaluate a GinkaModel on the ``ginka-eval.json`` split.

    Builds a ``GinkaModel``, loads the pretrained Minamo model from
    ``result/minamo.pth`` (used by ``GinkaLoss``), and scores every batch of
    ``GinkaDataset("ginka-eval.json")`` with ``GinkaLoss``. The mean of the
    per-batch losses is printed via ``tqdm.write`` and returned.

    Returns:
        float: Average validation loss over all batches, or ``nan`` when the
        validation set yields no batches.
    """
    print(f"Using {device.type} to validate model.")

    model = GinkaModel()
    # NOTE(review): no Ginka checkpoint is loaded here, so this evaluates a
    # freshly initialised model. Load trained weights before the loop if that
    # is the intent — TODO confirm checkpoint path/format.
    # The model must live on the same device as the input tensors below.
    model.to(device)
    model.eval()

    # Pretrained Minamo model consumed by GinkaLoss; its checkpoint stores the
    # weights under the "model_state" key.
    minamo = MinamoModel(32)
    minamo.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"])
    minamo.to(device)
    minamo.eval()

    # Prepare the dataset. Validation does not need shuffling; a fixed order
    # keeps the batch composition (and hence the averaged loss) reproducible.
    val_dataset = GinkaDataset("ginka-eval.json")
    val_loader = DataLoader(
        val_dataset,
        batch_size=32,
        shuffle=False,
    )

    converter = DynamicGraphConverter().to(device)
    criterion = GinkaLoss(minamo, converter)

    val_loss = 0.0
    with torch.no_grad():
        for batch in val_loader:
            # Move the batch to the compute device.
            target = batch["target"].to(device)
            target_vision_feat = batch["target_vision_feat"].to(device)
            target_topo_feat = batch["target_topo_feat"].to(device)
            # Model input: vision and topology features concatenated on the
            # last dimension (both operands are already on `device`).
            feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1)

            # Forward pass; the model's second output is unused here.
            output, _ = model(feat_vec)
            # Index of the maximum along dim 1 (presumably the tile/class
            # channel — TODO confirm against GinkaModel's output layout).
            map_matrix = torch.argmax(output, dim=1)

            # Compute the loss.
            loss = criterion(output, map_matrix, target, target_vision_feat, target_topo_feat)
            # Accumulate into val_loss (the original wrote to an undefined
            # `total_loss`, raising UnboundLocalError on the first batch).
            val_loss += loss.item()

    num_batches = len(val_loader)
    # Guard against an empty evaluation set instead of dividing by zero.
    avg_val_loss = val_loss / num_batches if num_batches else float("nan")
    tqdm.write(f"Validation::loss: {avg_val_loss:.6f}")
    return avg_val_loss
|
|
|
|
if __name__ == "__main__":
    # Cap PyTorch's CPU intra-op thread pool at two threads, then run the evaluation.
    torch.set_num_threads(2)
    validate()
|
|
|