mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
Merge branch 'master' of github.com:unanmed/ginka-generator
This commit is contained in:
commit
caa6684675
@ -1,4 +1,5 @@
|
||||
import os
|
||||
import json
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -11,6 +12,7 @@ from .model.loss import GinkaLoss
|
||||
from .model.model import GinkaModel
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
os.makedirs('result/ginka_img', exist_ok=True)
|
||||
|
||||
def blend_alpha(bg, fg, alpha):
|
||||
""" 使用 alpha 通道混合前景图块和背景图 """
|
||||
@ -35,8 +37,8 @@ def matrix_to_image_cv(map_matrix, tile_set, tile_size=32):
|
||||
x, y = col * tile_size, row * tile_size # 计算像素位置
|
||||
|
||||
# 先绘制地面(0)
|
||||
if 0 in tile_set:
|
||||
canvas[y:y+tile_size, x:x+tile_size] = tile_set[0][:, :, :3] # 仅填充 RGB
|
||||
if '0' in tile_set:
|
||||
canvas[y:y+tile_size, x:x+tile_size] = tile_set['0'][:, :, :3] # 仅填充 RGB
|
||||
|
||||
if tile_index == '11':
|
||||
if row == 0:
|
||||
@ -64,13 +66,16 @@ def matrix_to_image_cv(map_matrix, tile_set, tile_size=32):
|
||||
def validate():
|
||||
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.")
|
||||
model = GinkaModel()
|
||||
state = torch.load("result/ginka_checkpoint/15.pth", map_location=device)["model_state"]
|
||||
model.load_state_dict(state)
|
||||
model.to(device)
|
||||
|
||||
minamo = MinamoModel(32)
|
||||
minamo.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"])
|
||||
minamo.to(device)
|
||||
|
||||
# 准备数据集
|
||||
val_dataset = GinkaDataset("ginka-eval.json")
|
||||
val_dataset = GinkaDataset("ginka-eval.json", device, minamo)
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=32,
|
||||
@ -80,9 +85,10 @@ def validate():
|
||||
criterion = GinkaLoss(minamo)
|
||||
|
||||
tile_dict = dict()
|
||||
val_output = dict()
|
||||
|
||||
for file in os.listdir('tiles'):
|
||||
name = os.path.basename(file)
|
||||
name = os.path.splitext(file)[0]
|
||||
tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED)
|
||||
|
||||
minamo.eval()
|
||||
@ -90,28 +96,32 @@ def validate():
|
||||
val_loss = 0
|
||||
idx = 0
|
||||
with torch.no_grad():
|
||||
for batch in val_loader:
|
||||
for batch in tqdm(val_loader):
|
||||
# 数据迁移到设备
|
||||
target = batch["target"].to(device)
|
||||
target_vision_feat = batch["target_vision_feat"].to(device)
|
||||
target_topo_feat = batch["target_topo_feat"].to(device)
|
||||
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device)
|
||||
# 前向传播
|
||||
output, _ = model(feat_vec)
|
||||
output = model(feat_vec)
|
||||
map_matrix = torch.argmax(output, dim=1)
|
||||
|
||||
for matrix in map_matrix[:].cpu().numpy():
|
||||
image = matrix_to_image_cv(matrix, tile_dict)
|
||||
cv2.imwrite(f"result/{idx}.png", image)
|
||||
for matrix in map_matrix[:].cpu():
|
||||
image = matrix_to_image_cv(matrix.numpy(), tile_dict)
|
||||
cv2.imwrite(f"result/ginka_img/{idx}.png", image)
|
||||
val_output[f"val_{idx}"] = matrix.tolist()
|
||||
idx += 1
|
||||
|
||||
# 计算损失
|
||||
loss = criterion(output, target, target_vision_feat, target_topo_feat)
|
||||
total_loss += loss.item()
|
||||
val_loss += loss.item()
|
||||
|
||||
avg_val_loss = val_loss / len(val_loader)
|
||||
tqdm.write(f"Validation::loss: {avg_val_loss:.6f}")
|
||||
|
||||
with open('result/ginka_val.json', 'w') as f:
|
||||
json.dump(val_output, f)
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.set_num_threads(2)
|
||||
validate()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user