mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-17 15:01:10 +08:00
feat: 验证地图输出为图片
This commit is contained in:
parent
b2e1acb617
commit
d6b2b13ac8
@ -1,3 +1,6 @@
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch_geometric.loader import DataLoader
|
||||
@ -9,6 +12,55 @@ from .model.model import GinkaModel
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
def blend_alpha(bg, fg, alpha):
|
||||
""" 使用 alpha 通道混合前景图块和背景图 """
|
||||
for c in range(3): # 只混合 RGB 三个通道
|
||||
bg[:, :, c] = (1 - alpha) * bg[:, :, c] + alpha * fg[:, :, c]
|
||||
return bg
|
||||
|
||||
def matrix_to_image_cv(map_matrix, tile_set, tile_size=32):
|
||||
"""
|
||||
使用OpenCV加速的版本(适合大尺寸地图)
|
||||
:param map_matrix: [H, W] 的numpy数组
|
||||
:param tile_set: 字典 {tile_id: cv2图像(BGR格式)}
|
||||
:param tile_size: 图块边长(像素)
|
||||
"""
|
||||
H, W = map_matrix.shape # 获取地图尺寸
|
||||
canvas = np.zeros((H * tile_size, W * tile_size, 3), dtype=np.uint8) # 画布(黑色背景)
|
||||
|
||||
# 遍历地图矩阵
|
||||
for row in range(H):
|
||||
for col in range(W):
|
||||
tile_index = str(map_matrix[row, col]) # 获取当前坐标的图块类型
|
||||
x, y = col * tile_size, row * tile_size # 计算像素位置
|
||||
|
||||
# 先绘制地面(0)
|
||||
if 0 in tile_set:
|
||||
canvas[y:y+tile_size, x:x+tile_size] = tile_set[0][:, :, :3] # 仅填充 RGB
|
||||
|
||||
if tile_index == '11':
|
||||
if row == 0:
|
||||
tile_index = '11_2'
|
||||
elif row == W - 1:
|
||||
tile_index = '11_4'
|
||||
elif col == 0:
|
||||
tile_index = '11_1'
|
||||
elif col == H - 1:
|
||||
tile_index = '11_3'
|
||||
|
||||
# 叠加其他透明图块
|
||||
if tile_index in tile_set and tile_index != 0:
|
||||
tile_rgba = tile_set[tile_index]
|
||||
tile_rgb = tile_rgba[:, :, :3] # 提取 RGB
|
||||
alpha = tile_rgba[:, :, 3] / 255.0 # 归一化 alpha
|
||||
|
||||
# 混合当前图块到背景
|
||||
canvas[y:y+tile_size, x:x+tile_size] = blend_alpha(
|
||||
canvas[y:y+tile_size, x:x+tile_size], tile_rgb, alpha
|
||||
)
|
||||
|
||||
return canvas
|
||||
|
||||
def validate():
|
||||
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.")
|
||||
model = GinkaModel()
|
||||
@ -27,9 +79,16 @@ def validate():
|
||||
|
||||
criterion = GinkaLoss(minamo)
|
||||
|
||||
tile_dict = dict()
|
||||
|
||||
for file in os.listdir('tiles'):
|
||||
name = os.path.basename(file)
|
||||
tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED)
|
||||
|
||||
minamo.eval()
|
||||
model.eval()
|
||||
val_loss = 0
|
||||
idx = 0
|
||||
with torch.no_grad():
|
||||
for batch in val_loader:
|
||||
# 数据迁移到设备
|
||||
@ -41,6 +100,11 @@ def validate():
|
||||
output, _ = model(feat_vec)
|
||||
map_matrix = torch.argmax(output, dim=1)
|
||||
|
||||
for matrix in map_matrix[:].cpu().numpy():
|
||||
image = matrix_to_image_cv(matrix, tile_dict)
|
||||
cv2.imwrite(f"result/{idx}.png", image)
|
||||
idx += 1
|
||||
|
||||
# 计算损失
|
||||
loss = criterion(output, target, target_vision_feat, target_topo_feat)
|
||||
total_loss += loss.item()
|
||||
|
||||
@ -122,8 +122,8 @@ def train():
|
||||
graph1 = graph1.to(device)
|
||||
graph2 = graph2.to(device)
|
||||
|
||||
vision_feat1, topo_feat1 = model(map1, graph1)
|
||||
vision_feat2, topo_feat2 = model(map2, graph2)
|
||||
vision_feat1, topo_feat1 = model(map1_val, graph1)
|
||||
vision_feat2, topo_feat2 = model(map2_val, graph2)
|
||||
|
||||
vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
|
||||
topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user