diff --git a/ginka/train_heatmap.py b/ginka/train_heatmap.py index b12f1af..e9c5775 100644 --- a/ginka/train_heatmap.py +++ b/ginka/train_heatmap.py @@ -41,7 +41,7 @@ RAND_RATIO = 0.15 # MaskGIT 生成设置 USE_MASK_GIT_PREVIEW = True NUM_LAYERS = 4 -D_MODEL = 128 +D_MODEL = 192 # Diffusion 生成设置 NUM_LAYERS_DIFFUSION = 4 D_MODEL_DIFFUSION = 128 @@ -85,6 +85,7 @@ def train(): num_classes=NUM_CLASSES, heatmap_channel=HEATMAP_CHANNEL, num_layers=NUM_LAYERS, d_model=D_MODEL ).to(device) + maskGIT.eval() model = GinkaHeatmapModel( T=T_DIFFUSION, heatmap_dim=HEATMAP_CHANNEL, d_model=D_MODEL_DIFFUSION, num_layers=NUM_LAYERS_DIFFUSION diff --git a/ginka/train_maskGIT.py b/ginka/train_maskGIT.py index f6938ee..7cbbe5b 100644 --- a/ginka/train_maskGIT.py +++ b/ginka/train_maskGIT.py @@ -49,7 +49,7 @@ BLUR_MAX_SIZE = 9 RAND_RATIO = 0.3 MASK_PROBS = [0.5, 0.5] # 纯随机,分块随机 NUM_LAYERS = 4 -D_MODEL = 128 +D_MODEL = 192 device = torch.device( "cuda:1" if torch.cuda.is_available() @@ -87,7 +87,7 @@ def train(): dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True) dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // VAL_BATCH_DIVIDER, shuffle=True) - optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2) + optimizer = optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-2) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=1e-6) # 用于生成图片