From 83d6c317049767357bc1fec7102bceda34049fb3 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Mon, 19 Jan 2026 22:34:08 +0800 Subject: [PATCH] =?UTF-8?q?chore:=20=E8=B0=83=E6=95=B4=20batch=5Fsize?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/train_vae.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ginka/train_vae.py b/ginka/train_vae.py index cec3143..dfb07f4 100644 --- a/ginka/train_vae.py +++ b/ginka/train_vae.py @@ -51,7 +51,7 @@ from shared.image import matrix_to_image_cv # 26-28. 三种等级的怪物 # 29. 入口,不区分楼梯和箭头 -BATCH_SIZE = 32 +BATCH_SIZE = 128 device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") os.makedirs("result", exist_ok=True) @@ -86,7 +86,7 @@ def train(): dataset = GinkaRNNDataset(args.train, device) dataset_val = GinkaRNNDataset(args.validate, device) dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True) - dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE) + dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // 4) optimizer_ginka = optim.AdamW(vae.parameters(), lr=1e-4, weight_decay=1e-4) scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=800, eta_min=1e-6)