Merge branch 'master' of github.com:unanmed/ginka-generator

This commit is contained in:
unanmed 2025-03-21 00:05:14 +08:00
commit caa6684675

View File

@ -1,4 +1,5 @@
import os
import json
import cv2
import numpy as np
import torch
@ -11,6 +12,7 @@ from .model.loss import GinkaLoss
from .model.model import GinkaModel
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs('result/ginka_img', exist_ok=True)
def blend_alpha(bg, fg, alpha):
""" 使用 alpha 通道混合前景图块和背景图 """
@ -35,8 +37,8 @@ def matrix_to_image_cv(map_matrix, tile_set, tile_size=32):
x, y = col * tile_size, row * tile_size # 计算像素位置
# 先绘制地面0
if 0 in tile_set:
canvas[y:y+tile_size, x:x+tile_size] = tile_set[0][:, :, :3] # 仅填充 RGB
if '0' in tile_set:
canvas[y:y+tile_size, x:x+tile_size] = tile_set['0'][:, :, :3] # 仅填充 RGB
if tile_index == '11':
if row == 0:
@ -64,13 +66,16 @@ def matrix_to_image_cv(map_matrix, tile_set, tile_size=32):
def validate():
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.")
model = GinkaModel()
state = torch.load("result/ginka_checkpoint/15.pth", map_location=device)["model_state"]
model.load_state_dict(state)
model.to(device)
minamo = MinamoModel(32)
minamo.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"])
minamo.to(device)
# 准备数据集
val_dataset = GinkaDataset("ginka-eval.json")
val_dataset = GinkaDataset("ginka-eval.json", device, minamo)
val_loader = DataLoader(
val_dataset,
batch_size=32,
@ -80,9 +85,10 @@ def validate():
criterion = GinkaLoss(minamo)
tile_dict = dict()
val_output = dict()
for file in os.listdir('tiles'):
name = os.path.basename(file)
name = os.path.splitext(file)[0]
tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED)
minamo.eval()
@ -90,28 +96,32 @@ def validate():
val_loss = 0
idx = 0
with torch.no_grad():
for batch in val_loader:
for batch in tqdm(val_loader):
# 数据迁移到设备
target = batch["target"].to(device)
target_vision_feat = batch["target_vision_feat"].to(device)
target_topo_feat = batch["target_topo_feat"].to(device)
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device)
# 前向传播
output, _ = model(feat_vec)
output = model(feat_vec)
map_matrix = torch.argmax(output, dim=1)
for matrix in map_matrix[:].cpu().numpy():
image = matrix_to_image_cv(matrix, tile_dict)
cv2.imwrite(f"result/{idx}.png", image)
for matrix in map_matrix[:].cpu():
image = matrix_to_image_cv(matrix.numpy(), tile_dict)
cv2.imwrite(f"result/ginka_img/{idx}.png", image)
val_output[f"val_{idx}"] = matrix.tolist()
idx += 1
# 计算损失
loss = criterion(output, target, target_vision_feat, target_topo_feat)
total_loss += loss.item()
val_loss += loss.item()
avg_val_loss = val_loss / len(val_loader)
tqdm.write(f"Validation::loss: {avg_val_loss:.6f}")
with open('result/ginka_val.json', 'w') as f:
json.dump(val_output, f)
if __name__ == "__main__":
torch.set_num_threads(2)
validate()