From 1237d45d95f3cc9cc7ad75967a4fdd15166d974c Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Fri, 10 Apr 2026 12:41:53 +0800 Subject: [PATCH] =?UTF-8?q?chore:=20=E8=B0=83=E6=95=B4=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/train_heatmap.py | 3 ++- ginka/train_maskGIT.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/ginka/train_heatmap.py b/ginka/train_heatmap.py index b12f1af..e9c5775 100644 --- a/ginka/train_heatmap.py +++ b/ginka/train_heatmap.py @@ -41,7 +41,7 @@ RAND_RATIO = 0.15 # MaskGIT 生成设置 USE_MASK_GIT_PREVIEW = True NUM_LAYERS = 4 -D_MODEL = 128 +D_MODEL = 192 # Diffusion 生成设置 NUM_LAYERS_DIFFUSION = 4 D_MODEL_DIFFUSION = 128 @@ -85,6 +85,7 @@ def train(): num_classes=NUM_CLASSES, heatmap_channel=HEATMAP_CHANNEL, num_layers=NUM_LAYERS, d_model=D_MODEL ).to(device) + maskGIT.eval() model = GinkaHeatmapModel( T=T_DIFFUSION, heatmap_dim=HEATMAP_CHANNEL, d_model=D_MODEL_DIFFUSION, num_layers=NUM_LAYERS_DIFFUSION diff --git a/ginka/train_maskGIT.py b/ginka/train_maskGIT.py index f6938ee..7cbbe5b 100644 --- a/ginka/train_maskGIT.py +++ b/ginka/train_maskGIT.py @@ -49,7 +49,7 @@ BLUR_MAX_SIZE = 9 RAND_RATIO = 0.3 MASK_PROBS = [0.5, 0.5] # 纯随机,分块随机 NUM_LAYERS = 4 -D_MODEL = 128 +D_MODEL = 192 device = torch.device( "cuda:1" if torch.cuda.is_available() @@ -87,7 +87,7 @@ def train(): dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True) dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // VAL_BATCH_DIVIDER, shuffle=True) - optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2) + optimizer = optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-2) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=1e-6) # 用于生成图片