mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
chore: 调整参数
This commit is contained in:
parent
df23c891c6
commit
1237d45d95
@ -41,7 +41,7 @@ RAND_RATIO = 0.15
|
||||
# MaskGIT 生成设置
|
||||
USE_MASK_GIT_PREVIEW = True
|
||||
NUM_LAYERS = 4
|
||||
D_MODEL = 128
|
||||
D_MODEL = 192
|
||||
# Diffusion 生成设置
|
||||
NUM_LAYERS_DIFFUSION = 4
|
||||
D_MODEL_DIFFUSION = 128
|
||||
@ -85,6 +85,7 @@ def train():
|
||||
num_classes=NUM_CLASSES, heatmap_channel=HEATMAP_CHANNEL,
|
||||
num_layers=NUM_LAYERS, d_model=D_MODEL
|
||||
).to(device)
|
||||
maskGIT.eval()
|
||||
model = GinkaHeatmapModel(
|
||||
T=T_DIFFUSION, heatmap_dim=HEATMAP_CHANNEL, d_model=D_MODEL_DIFFUSION,
|
||||
num_layers=NUM_LAYERS_DIFFUSION
|
||||
|
||||
@ -49,7 +49,7 @@ BLUR_MAX_SIZE = 9
|
||||
RAND_RATIO = 0.3
|
||||
MASK_PROBS = [0.5, 0.5] # 纯随机,分块随机
|
||||
NUM_LAYERS = 4
|
||||
D_MODEL = 128
|
||||
D_MODEL = 192
|
||||
|
||||
device = torch.device(
|
||||
"cuda:1" if torch.cuda.is_available()
|
||||
@ -87,7 +87,7 @@ def train():
|
||||
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // VAL_BATCH_DIVIDER, shuffle=True)
|
||||
|
||||
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
|
||||
optimizer = optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-2)
|
||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=1e-6)
|
||||
|
||||
# 用于生成图片
|
||||
|
||||
Loading…
Reference in New Issue
Block a user