mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-22 11:01:12 +08:00
chore: 调整 batch_size
This commit is contained in:
parent
79cf3ab226
commit
83d6c31704
@ -51,7 +51,7 @@ from shared.image import matrix_to_image_cv
|
|||||||
# 26-28. 三种等级的怪物
|
# 26-28. 三种等级的怪物
|
||||||
# 29. 入口,不区分楼梯和箭头
|
# 29. 入口,不区分楼梯和箭头
|
||||||
|
|
||||||
BATCH_SIZE = 32
|
BATCH_SIZE = 128
|
||||||
|
|
||||||
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
|
||||||
os.makedirs("result", exist_ok=True)
|
os.makedirs("result", exist_ok=True)
|
||||||
@ -86,7 +86,7 @@ def train():
|
|||||||
dataset = GinkaRNNDataset(args.train, device)
|
dataset = GinkaRNNDataset(args.train, device)
|
||||||
dataset_val = GinkaRNNDataset(args.validate, device)
|
dataset_val = GinkaRNNDataset(args.validate, device)
|
||||||
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||||
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE)
|
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // 4)
|
||||||
|
|
||||||
optimizer_ginka = optim.AdamW(vae.parameters(), lr=1e-4, weight_decay=1e-4)
|
optimizer_ginka = optim.AdamW(vae.parameters(), lr=1e-4, weight_decay=1e-4)
|
||||||
scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=800, eta_min=1e-6)
|
scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=800, eta_min=1e-6)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user