diff --git a/ginka/train_heatmap.py b/ginka/train_heatmap.py index 05cdd27..06898d8 100644 --- a/ginka/train_heatmap.py +++ b/ginka/train_heatmap.py @@ -26,8 +26,8 @@ from .utils import nms_sampling # 0. 墙壁热力图, 1. 怪物热力图, 2. 资源热力图, 3. 血瓶热力图, 4. 宝石热力图, 5. 钥匙热力图 # 6. 道具热力图, 7. 入口热力图, 8. 门热力图 -BATCH_SIZE = 8 -VAL_BATCH_DIVIDER = 8 +BATCH_SIZE = 128 +VAL_BATCH_DIVIDER = 64 NUM_CLASSES = 16 MASK_TOKEN = 15 GENERATE_STEP = 8