feat: 验证地图输出为图片

This commit is contained in:
unanmed 2025-03-19 22:15:19 +08:00
parent b2e1acb617
commit d6b2b13ac8
2 changed files with 66 additions and 2 deletions

View File

@ -1,3 +1,6 @@
import os
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
@ -9,6 +12,55 @@ from .model.model import GinkaModel
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def blend_alpha(bg, fg, alpha):
""" 使用 alpha 通道混合前景图块和背景图 """
for c in range(3): # 只混合 RGB 三个通道
bg[:, :, c] = (1 - alpha) * bg[:, :, c] + alpha * fg[:, :, c]
return bg
def matrix_to_image_cv(map_matrix, tile_set, tile_size=32):
"""
使用OpenCV加速的版本适合大尺寸地图
:param map_matrix: [H, W] 的numpy数组
:param tile_set: 字典 {tile_id: cv2图像BGR格式}
:param tile_size: 图块边长像素
"""
H, W = map_matrix.shape # 获取地图尺寸
canvas = np.zeros((H * tile_size, W * tile_size, 3), dtype=np.uint8) # 画布(黑色背景)
# 遍历地图矩阵
for row in range(H):
for col in range(W):
tile_index = str(map_matrix[row, col]) # 获取当前坐标的图块类型
x, y = col * tile_size, row * tile_size # 计算像素位置
# 先绘制地面0
if 0 in tile_set:
canvas[y:y+tile_size, x:x+tile_size] = tile_set[0][:, :, :3] # 仅填充 RGB
if tile_index == '11':
if row == 0:
tile_index = '11_2'
elif row == W - 1:
tile_index = '11_4'
elif col == 0:
tile_index = '11_1'
elif col == H - 1:
tile_index = '11_3'
# 叠加其他透明图块
if tile_index in tile_set and tile_index != 0:
tile_rgba = tile_set[tile_index]
tile_rgb = tile_rgba[:, :, :3] # 提取 RGB
alpha = tile_rgba[:, :, 3] / 255.0 # 归一化 alpha
# 混合当前图块到背景
canvas[y:y+tile_size, x:x+tile_size] = blend_alpha(
canvas[y:y+tile_size, x:x+tile_size], tile_rgb, alpha
)
return canvas
def validate():
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.")
model = GinkaModel()
@ -27,9 +79,16 @@ def validate():
criterion = GinkaLoss(minamo)
tile_dict = dict()
for file in os.listdir('tiles'):
name = os.path.basename(file)
tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED)
minamo.eval()
model.eval()
val_loss = 0
idx = 0
with torch.no_grad():
for batch in val_loader:
# 数据迁移到设备
@ -41,6 +100,11 @@ def validate():
output, _ = model(feat_vec)
map_matrix = torch.argmax(output, dim=1)
for matrix in map_matrix[:].cpu().numpy():
image = matrix_to_image_cv(matrix, tile_dict)
cv2.imwrite(f"result/{idx}.png", image)
idx += 1
# 计算损失
loss = criterion(output, target, target_vision_feat, target_topo_feat)
total_loss += loss.item()

View File

@ -122,8 +122,8 @@ def train():
graph1 = graph1.to(device)
graph2 = graph2.to(device)
vision_feat1, topo_feat1 = model(map1, graph1)
vision_feat2, topo_feat2 = model(map2, graph2)
vision_feat1, topo_feat1 = model(map1_val, graph1)
vision_feat2, topo_feat2 = model(map2_val, graph2)
vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)