diff --git a/ginka/train_maskGIT.py b/ginka/train_maskGIT.py index 67e56cf..8b2a636 100644 --- a/ginka/train_maskGIT.py +++ b/ginka/train_maskGIT.py @@ -43,7 +43,7 @@ MASK_TOKEN = 15 GENERATE_STEP = 8 MAP_SIZE = 13 * 13 HEATMAP_CHANNEL = 9 -LABEL_SMOOTHING = 0.1 +LABEL_SMOOTHING = 0 RAND_RATIO = 0.1 MASK_PROBS = [0.5, 0.5] # 纯随机,分块随机 @@ -85,7 +85,7 @@ def train(): optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2) # 自定义调度器允许在 self_prob 提高时重置调度器信息并提高学习率以适应学习 - scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=1e-5) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=1e-6) # 用于生成图片 tile_dict = dict()