mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
chore: 调整超参数
This commit is contained in:
parent
a69403d6bf
commit
294e214431
@ -40,8 +40,8 @@ MAP_SIZE = 13 * 13
|
||||
MAP_H = MAP_W = 13
|
||||
|
||||
# VQ-VAE 超参(保持与 train_vq.py 一致)
|
||||
VQ_L = 32
|
||||
VQ_K = 1
|
||||
VQ_L = 2
|
||||
VQ_K = 4
|
||||
VQ_D_Z = 128
|
||||
VQ_D_MODEL= 192
|
||||
VQ_NHEAD = 8
|
||||
|
||||
@ -46,8 +46,8 @@ LABEL_SMOOTHING = 0.0
|
||||
WALL_MASK_RATIO = 0.8
|
||||
|
||||
# VQ-VAE 超参
|
||||
VQ_L = 32 # summary token 数量(即 z 的序列长度)
|
||||
VQ_K = 1 # codebook 大小
|
||||
VQ_L = 2 # summary token 数量(即 z 的序列长度)
|
||||
VQ_K = 4 # codebook 大小
|
||||
VQ_D_Z = 128 # codebook 嵌入维度
|
||||
VQ_D_MODEL= 192
|
||||
VQ_NHEAD = 8
|
||||
|
||||
@ -21,12 +21,12 @@ EVAL_DATA="ginka-eval.json"
|
||||
|
||||
# 阶段 0:预训练
|
||||
P0_EPOCHS=50
|
||||
P0_CHECKPOINT=5
|
||||
P0_CHECKPOINT=10
|
||||
P0_FINAL="result/pretrain/pretrain_final.pth"
|
||||
|
||||
# 阶段 1:冻结编码器热身
|
||||
P1_EPOCHS=30
|
||||
P1_CHECKPOINT=5
|
||||
P1_CHECKPOINT=10
|
||||
P1_FINAL="result/joint/warmup_final.pth"
|
||||
|
||||
# 阶段 2:完整联合训练
|
||||
@ -72,7 +72,7 @@ die() {
|
||||
# ------------------------------------------------------------------------------
|
||||
if [[ $START_PHASE -le 0 ]]; then
|
||||
log "阶段 0 / 3 VQ 编码器预训练 (epochs=${P0_EPOCHS})"
|
||||
python -m ginka.train_pretrain \
|
||||
python3 -u -m ginka.train_pretrain \
|
||||
--train "$TRAIN_DATA" \
|
||||
--validate "$EVAL_DATA" \
|
||||
--epochs "$P0_EPOCHS" \
|
||||
@ -90,7 +90,7 @@ fi
|
||||
# ------------------------------------------------------------------------------
|
||||
if [[ $START_PHASE -le 1 ]]; then
|
||||
log "阶段 1 / 3 MaskGIT 热身(VQ 冻结) (epochs=${P1_EPOCHS})"
|
||||
python -m ginka.train_vq \
|
||||
python3 -u -m ginka.train_vq \
|
||||
--train "$TRAIN_DATA" \
|
||||
--validate "$EVAL_DATA" \
|
||||
--resume True \
|
||||
@ -116,7 +116,7 @@ fi
|
||||
# ------------------------------------------------------------------------------
|
||||
if [[ $START_PHASE -le 2 ]]; then
|
||||
log "阶段 2 / 3 完整联合训练 (epochs=${P2_EPOCHS})"
|
||||
python -m ginka.train_vq \
|
||||
python3 -u -m ginka.train_vq \
|
||||
--train "$TRAIN_DATA" \
|
||||
--validate "$EVAL_DATA" \
|
||||
--resume True \
|
||||
|
||||
Loading…
Reference in New Issue
Block a user