diff --git a/ginka/train_vq.py b/ginka/train_vq.py index f67e70f..a226a8d 100644 --- a/ginka/train_vq.py +++ b/ginka/train_vq.py @@ -174,12 +174,20 @@ def make_map_image(map_flat: torch.Tensor, tile_dict: dict) -> np.ndarray: def hstack_images(imgs: list, gap: int = 4, color=(255, 255, 255)) -> np.ndarray: - """将多张等高图片横向拼接,之间插入白色竖线。""" - H = imgs[0].shape[0] - vline = np.full((H, gap, 3), color, dtype=np.uint8) - result = imgs[0] + """将多张图片横向拼接,之间插入竖线;高度不一致时底部补齐背景色。""" + max_h = max(img.shape[0] for img in imgs) + + def _pad_h(img): + dh = max_h - img.shape[0] + if dh == 0: + return img + pad = np.full((dh, img.shape[1], 3), color, dtype=np.uint8) + return np.concatenate([img, pad], axis=0) + + vline = np.full((max_h, gap, 3), color, dtype=np.uint8) + result = _pad_h(imgs[0]) for img in imgs[1:]: - result = np.concatenate([result, vline, img], axis=1) + result = np.concatenate([result, vline, _pad_h(img)], axis=1) return result