From c9c52109ed96a01391fd10ce9381d5319fb95b39 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Fri, 21 Mar 2025 13:25:12 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E6=8C=87=E5=AE=9A=E8=AE=AD=E7=BB=83?= =?UTF-8?q?=E9=9B=86=E5=92=8C=E9=AA=8C=E8=AF=81=E9=9B=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/train.py | 4 ++-- minamo/train.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ginka/train.py b/ginka/train.py index f86e75a..b4ec39f 100644 --- a/ginka/train.py +++ b/ginka/train.py @@ -34,8 +34,8 @@ def train(): # param.requires_grad = False # 准备数据集 - dataset = GinkaDataset("ginka-dataset.json", device, minamo) - dataset_val = GinkaDataset("ginka-eval.json", device, minamo) + dataset = GinkaDataset(args.train, device, minamo) + dataset_val = GinkaDataset(args.validate, device, minamo) dataloader = DataLoader( dataset, batch_size=32, diff --git a/minamo/train.py b/minamo/train.py index bc3a1d9..cb90c9c 100644 --- a/minamo/train.py +++ b/minamo/train.py @@ -41,8 +41,8 @@ def train(): model.to(device) # 准备数据集 - dataset = MinamoDataset("minamo-dataset.json") - val_dataset = MinamoDataset("minamo-eval.json") + dataset = MinamoDataset(args.train) + val_dataset = MinamoDataset(args.validate) dataloader = DataLoader( dataset, batch_size=64,