diff --git a/ginka/train_vq.py b/ginka/train_vq.py index fe0bf0b..3035a9d 100644 --- a/ginka/train_vq.py +++ b/ginka/train_vq.py @@ -89,8 +89,8 @@ def parse_arguments(): parser.add_argument("--resume", type=bool, default=False) parser.add_argument("--state", type=str, default="result/joint/joint-10.pth", help="续训时加载的检查点路径") - parser.add_argument("--train", type=str, default="data/ginka-dataset.json") - parser.add_argument("--validate", type=str, default="data/ginka-eval.json") + parser.add_argument("--train", type=str, default="ginka-dataset.json") + parser.add_argument("--validate", type=str, default="ginka-eval.json") parser.add_argument("--epochs", type=int, default=100) parser.add_argument("--checkpoint", type=int, default=5, help="每隔多少 epoch 保存检查点并验证")