diff --git a/ginka/validate.py b/ginka/validate.py
index 71eaf59..428b048 100644
--- a/ginka/validate.py
+++ b/ginka/validate.py
@@ -1,3 +1,6 @@
+import os
+import cv2
+import numpy as np
 import torch
 import torch.nn.functional as F
 from torch_geometric.loader import DataLoader
@@ -9,6 +12,55 @@ from .model.model import GinkaModel
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
+def blend_alpha(bg, fg, alpha):
+    """Alpha-blend the RGB channels of fg onto bg in place and return bg."""
+    for c in range(3):  # blend only the three RGB channels
+        bg[:, :, c] = (1 - alpha) * bg[:, :, c] + alpha * fg[:, :, c]
+    return bg
+
+def matrix_to_image_cv(map_matrix, tile_set, tile_size=32):
+    """
+    OpenCV-based map renderer (suited to large maps).
+    :param map_matrix: [H, W] numpy array of tile ids
+    :param tile_set: dict {tile_id (str): cv2 image, BGR or BGRA}
+    :param tile_size: tile edge length in pixels
+    """
+    H, W = map_matrix.shape  # map size in tiles
+    canvas = np.zeros((H * tile_size, W * tile_size, 3), dtype=np.uint8)  # black canvas
+
+    # walk the map matrix
+    for row in range(H):
+        for col in range(W):
+            tile_index = str(map_matrix[row, col])  # tile id at this cell
+            x, y = col * tile_size, row * tile_size  # pixel position of the cell
+
+            # draw the ground tile first; keys are strings, so '0' not 0
+            if '0' in tile_set:
+                canvas[y:y+tile_size, x:x+tile_size] = tile_set['0'][:, :, :3]  # RGB only
+
+            if tile_index == '11':
+                if row == 0:
+                    tile_index = '11_2'
+                elif row == H - 1:  # bottom edge: rows are bounded by H, not W
+                    tile_index = '11_4'
+                elif col == 0:
+                    tile_index = '11_1'
+                elif col == W - 1:  # right edge: cols are bounded by W, not H
+                    tile_index = '11_3'
+
+            # overlay the (possibly transparent) tile; skip plain ground
+            if tile_index in tile_set and tile_index != '0':
+                tile_rgba = tile_set[tile_index]
+                tile_rgb = tile_rgba[:, :, :3]  # RGB part
+                alpha = tile_rgba[:, :, 3] / 255.0  # normalized alpha — assumes BGRA PNG, TODO confirm
+
+                # blend this tile onto the canvas
+                canvas[y:y+tile_size, x:x+tile_size] = blend_alpha(
+                    canvas[y:y+tile_size, x:x+tile_size], tile_rgb, alpha
+                )
+
+    return canvas
+
 def validate():
     print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.")
     model = GinkaModel()
@@ -27,9 +79,16 @@ def validate():
 
     criterion = GinkaLoss(minamo)
 
+    tile_dict = dict()
+
+    for file in os.listdir('tiles'):
+        name = os.path.splitext(file)[0]  # strip extension so keys match str(tile_id), e.g. '11'
+        tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED)
+
     minamo.eval()
     model.eval()
     val_loss = 0
+    idx = 0
     with torch.no_grad():
         for batch in val_loader:
             # 数据迁移到设备
@@ -41,6 +100,11 @@ def validate():
             output, _ = model(feat_vec)
             map_matrix = torch.argmax(output, dim=1)
 
+            for matrix in map_matrix.cpu().numpy():
+                image = matrix_to_image_cv(matrix, tile_dict)
+                cv2.imwrite(f"result/{idx}.png", image)
+                idx += 1
+
             # 计算损失
             loss = criterion(output, target, target_vision_feat, target_topo_feat)
             total_loss += loss.item()
diff --git a/minamo/train.py b/minamo/train.py
index dbe2775..153ac90 100644
--- a/minamo/train.py
+++ b/minamo/train.py
@@ -122,8 +122,8 @@ def train():
             graph1 = graph1.to(device)
             graph2 = graph2.to(device)
 
-            vision_feat1, topo_feat1 = model(map1, graph1)
-            vision_feat2, topo_feat2 = model(map2, graph2)
+            vision_feat1, topo_feat1 = model(map1_val, graph1)
+            vision_feat2, topo_feat2 = model(map2_val, graph2)
 
             vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
             topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)