From f49705a556b5428bc6577be41d7e7428f249eb21 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Thu, 20 Mar 2025 23:14:47 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20GINKA=20=E9=AA=8C=E8=AF=81=E8=84=9A?= =?UTF-8?q?=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/validate.py | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/ginka/validate.py b/ginka/validate.py index 428b048..60318d8 100644 --- a/ginka/validate.py +++ b/ginka/validate.py @@ -1,4 +1,5 @@ import os +import json import cv2 import numpy as np import torch @@ -11,6 +12,7 @@ from .model.loss import GinkaLoss from .model.model import GinkaModel device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +os.makedirs('result/ginka_img', exist_ok=True) def blend_alpha(bg, fg, alpha): """ 使用 alpha 通道混合前景图块和背景图 """ @@ -35,8 +37,8 @@ def matrix_to_image_cv(map_matrix, tile_set, tile_size=32): x, y = col * tile_size, row * tile_size # 计算像素位置 # 先绘制地面(0) - if 0 in tile_set: - canvas[y:y+tile_size, x:x+tile_size] = tile_set[0][:, :, :3] # 仅填充 RGB + if '0' in tile_set: + canvas[y:y+tile_size, x:x+tile_size] = tile_set['0'][:, :, :3] # 仅填充 RGB if tile_index == '11': if row == 0: @@ -64,13 +66,16 @@ def matrix_to_image_cv(map_matrix, tile_set, tile_size=32): def validate(): print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.") model = GinkaModel() + state = torch.load("result/ginka_checkpoint/15.pth", map_location=device)["model_state"] + model.load_state_dict(state) + model.to(device) minamo = MinamoModel(32) minamo.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"]) minamo.to(device) # 准备数据集 - val_dataset = GinkaDataset("ginka-eval.json") + val_dataset = GinkaDataset("ginka-eval.json", device, minamo) val_loader = DataLoader( val_dataset, batch_size=32, @@ -80,9 +85,10 @@ def validate(): criterion = GinkaLoss(minamo) tile_dict = dict() + val_output = dict() for file in os.listdir('tiles'): - name = os.path.basename(file) + name = os.path.splitext(file)[0] tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED) minamo.eval() @@ -90,28 +96,32 @@ def validate(): val_loss = 0 idx = 0 with torch.no_grad(): - for batch in val_loader: + for batch in tqdm(val_loader): # 数据迁移到设备 target = batch["target"].to(device) target_vision_feat = batch["target_vision_feat"].to(device) target_topo_feat = batch["target_topo_feat"].to(device) feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device) # 前向传播 - output, _ = model(feat_vec) + output = model(feat_vec) map_matrix = torch.argmax(output, dim=1) - for matrix in map_matrix[:].cpu().numpy(): - image = matrix_to_image_cv(matrix, tile_dict) - cv2.imwrite(f"result/{idx}.png", image) + for matrix in map_matrix[:].cpu(): + image = matrix_to_image_cv(matrix.numpy(), tile_dict) + cv2.imwrite(f"result/ginka_img/{idx}.png", image) + val_output[f"val_{idx}"] = matrix.tolist() idx += 1 # 计算损失 loss = criterion(output, target, target_vision_feat, target_topo_feat) - total_loss += loss.item() + val_loss += loss.item() avg_val_loss = val_loss / len(val_loader) tqdm.write(f"Validation::loss: {avg_val_loss:.6f}") + with open('result/ginka_val.json', 'w') as f: + json.dump(val_output, f) + if __name__ == "__main__": torch.set_num_threads(2) validate()