diff --git a/ginka/train_heatmap.py b/ginka/train_heatmap.py index 81d9601..46a213f 100644 --- a/ginka/train_heatmap.py +++ b/ginka/train_heatmap.py @@ -195,6 +195,8 @@ def train(): generated_img = matrix_to_image_cv(map.view(B, H, W)[0].cpu().numpy(), tile_dict) cv2.imwrite(f"result/final_img/{idx}.png", generated_img) + idx += 1 + # 3. 完全随机生成五张图 if args.use_maskgit: for i in range(0, 5):