mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-19 08:21:11 +08:00
chore: 调整参数
This commit is contained in:
parent
df23c891c6
commit
1237d45d95
@ -41,7 +41,7 @@ RAND_RATIO = 0.15
|
|||||||
# MaskGIT 生成设置
|
# MaskGIT 生成设置
|
||||||
USE_MASK_GIT_PREVIEW = True
|
USE_MASK_GIT_PREVIEW = True
|
||||||
NUM_LAYERS = 4
|
NUM_LAYERS = 4
|
||||||
D_MODEL = 128
|
D_MODEL = 192
|
||||||
# Diffusion 生成设置
|
# Diffusion 生成设置
|
||||||
NUM_LAYERS_DIFFUSION = 4
|
NUM_LAYERS_DIFFUSION = 4
|
||||||
D_MODEL_DIFFUSION = 128
|
D_MODEL_DIFFUSION = 128
|
||||||
@ -85,6 +85,7 @@ def train():
|
|||||||
num_classes=NUM_CLASSES, heatmap_channel=HEATMAP_CHANNEL,
|
num_classes=NUM_CLASSES, heatmap_channel=HEATMAP_CHANNEL,
|
||||||
num_layers=NUM_LAYERS, d_model=D_MODEL
|
num_layers=NUM_LAYERS, d_model=D_MODEL
|
||||||
).to(device)
|
).to(device)
|
||||||
|
maskGIT.eval()
|
||||||
model = GinkaHeatmapModel(
|
model = GinkaHeatmapModel(
|
||||||
T=T_DIFFUSION, heatmap_dim=HEATMAP_CHANNEL, d_model=D_MODEL_DIFFUSION,
|
T=T_DIFFUSION, heatmap_dim=HEATMAP_CHANNEL, d_model=D_MODEL_DIFFUSION,
|
||||||
num_layers=NUM_LAYERS_DIFFUSION
|
num_layers=NUM_LAYERS_DIFFUSION
|
||||||
|
|||||||
@ -49,7 +49,7 @@ BLUR_MAX_SIZE = 9
|
|||||||
RAND_RATIO = 0.3
|
RAND_RATIO = 0.3
|
||||||
MASK_PROBS = [0.5, 0.5] # 纯随机,分块随机
|
MASK_PROBS = [0.5, 0.5] # 纯随机,分块随机
|
||||||
NUM_LAYERS = 4
|
NUM_LAYERS = 4
|
||||||
D_MODEL = 128
|
D_MODEL = 192
|
||||||
|
|
||||||
device = torch.device(
|
device = torch.device(
|
||||||
"cuda:1" if torch.cuda.is_available()
|
"cuda:1" if torch.cuda.is_available()
|
||||||
@ -87,7 +87,7 @@ def train():
|
|||||||
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||||
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // VAL_BATCH_DIVIDER, shuffle=True)
|
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // VAL_BATCH_DIVIDER, shuffle=True)
|
||||||
|
|
||||||
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
|
optimizer = optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-2)
|
||||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=1e-6)
|
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=1e-6)
|
||||||
|
|
||||||
# 用于生成图片
|
# 用于生成图片
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user