mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-22 02:44:51 +08:00
feat: 验证地图输出为图片
This commit is contained in:
parent
b2e1acb617
commit
d6b2b13ac8
@ -1,3 +1,6 @@
|
|||||||
|
import os
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch_geometric.loader import DataLoader
|
from torch_geometric.loader import DataLoader
|
||||||
@ -9,6 +12,55 @@ from .model.model import GinkaModel
|
|||||||
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
def blend_alpha(bg, fg, alpha):
|
||||||
|
""" 使用 alpha 通道混合前景图块和背景图 """
|
||||||
|
for c in range(3): # 只混合 RGB 三个通道
|
||||||
|
bg[:, :, c] = (1 - alpha) * bg[:, :, c] + alpha * fg[:, :, c]
|
||||||
|
return bg
|
||||||
|
|
||||||
|
def matrix_to_image_cv(map_matrix, tile_set, tile_size=32):
|
||||||
|
"""
|
||||||
|
使用OpenCV加速的版本(适合大尺寸地图)
|
||||||
|
:param map_matrix: [H, W] 的numpy数组
|
||||||
|
:param tile_set: 字典 {tile_id: cv2图像(BGR格式)}
|
||||||
|
:param tile_size: 图块边长(像素)
|
||||||
|
"""
|
||||||
|
H, W = map_matrix.shape # 获取地图尺寸
|
||||||
|
canvas = np.zeros((H * tile_size, W * tile_size, 3), dtype=np.uint8) # 画布(黑色背景)
|
||||||
|
|
||||||
|
# 遍历地图矩阵
|
||||||
|
for row in range(H):
|
||||||
|
for col in range(W):
|
||||||
|
tile_index = str(map_matrix[row, col]) # 获取当前坐标的图块类型
|
||||||
|
x, y = col * tile_size, row * tile_size # 计算像素位置
|
||||||
|
|
||||||
|
# 先绘制地面(0)
|
||||||
|
if 0 in tile_set:
|
||||||
|
canvas[y:y+tile_size, x:x+tile_size] = tile_set[0][:, :, :3] # 仅填充 RGB
|
||||||
|
|
||||||
|
if tile_index == '11':
|
||||||
|
if row == 0:
|
||||||
|
tile_index = '11_2'
|
||||||
|
elif row == W - 1:
|
||||||
|
tile_index = '11_4'
|
||||||
|
elif col == 0:
|
||||||
|
tile_index = '11_1'
|
||||||
|
elif col == H - 1:
|
||||||
|
tile_index = '11_3'
|
||||||
|
|
||||||
|
# 叠加其他透明图块
|
||||||
|
if tile_index in tile_set and tile_index != 0:
|
||||||
|
tile_rgba = tile_set[tile_index]
|
||||||
|
tile_rgb = tile_rgba[:, :, :3] # 提取 RGB
|
||||||
|
alpha = tile_rgba[:, :, 3] / 255.0 # 归一化 alpha
|
||||||
|
|
||||||
|
# 混合当前图块到背景
|
||||||
|
canvas[y:y+tile_size, x:x+tile_size] = blend_alpha(
|
||||||
|
canvas[y:y+tile_size, x:x+tile_size], tile_rgb, alpha
|
||||||
|
)
|
||||||
|
|
||||||
|
return canvas
|
||||||
|
|
||||||
def validate():
|
def validate():
|
||||||
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.")
|
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.")
|
||||||
model = GinkaModel()
|
model = GinkaModel()
|
||||||
@ -27,9 +79,16 @@ def validate():
|
|||||||
|
|
||||||
criterion = GinkaLoss(minamo)
|
criterion = GinkaLoss(minamo)
|
||||||
|
|
||||||
|
tile_dict = dict()
|
||||||
|
|
||||||
|
for file in os.listdir('tiles'):
|
||||||
|
name = os.path.basename(file)
|
||||||
|
tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED)
|
||||||
|
|
||||||
minamo.eval()
|
minamo.eval()
|
||||||
model.eval()
|
model.eval()
|
||||||
val_loss = 0
|
val_loss = 0
|
||||||
|
idx = 0
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for batch in val_loader:
|
for batch in val_loader:
|
||||||
# 数据迁移到设备
|
# 数据迁移到设备
|
||||||
@ -41,6 +100,11 @@ def validate():
|
|||||||
output, _ = model(feat_vec)
|
output, _ = model(feat_vec)
|
||||||
map_matrix = torch.argmax(output, dim=1)
|
map_matrix = torch.argmax(output, dim=1)
|
||||||
|
|
||||||
|
for matrix in map_matrix[:].cpu().numpy():
|
||||||
|
image = matrix_to_image_cv(matrix, tile_dict)
|
||||||
|
cv2.imwrite(f"result/{idx}.png", image)
|
||||||
|
idx += 1
|
||||||
|
|
||||||
# 计算损失
|
# 计算损失
|
||||||
loss = criterion(output, target, target_vision_feat, target_topo_feat)
|
loss = criterion(output, target, target_vision_feat, target_topo_feat)
|
||||||
total_loss += loss.item()
|
total_loss += loss.item()
|
||||||
|
|||||||
@ -122,8 +122,8 @@ def train():
|
|||||||
graph1 = graph1.to(device)
|
graph1 = graph1.to(device)
|
||||||
graph2 = graph2.to(device)
|
graph2 = graph2.to(device)
|
||||||
|
|
||||||
vision_feat1, topo_feat1 = model(map1, graph1)
|
vision_feat1, topo_feat1 = model(map1_val, graph1)
|
||||||
vision_feat2, topo_feat2 = model(map2, graph2)
|
vision_feat2, topo_feat2 = model(map2_val, graph2)
|
||||||
|
|
||||||
vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
|
vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
|
||||||
topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
|
topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user