chore: 调整超参数

This commit is contained in:
unanmed 2026-04-28 16:46:02 +08:00
parent a69403d6bf
commit 294e214431
3 changed files with 9 additions and 9 deletions

View File

@ -40,8 +40,8 @@ MAP_SIZE = 13 * 13
MAP_H = MAP_W = 13 MAP_H = MAP_W = 13
# VQ-VAE 超参(保持与 train_vq.py 一致) # VQ-VAE 超参(保持与 train_vq.py 一致)
VQ_L = 32 VQ_L = 2
VQ_K = 1 VQ_K = 4
VQ_D_Z = 128 VQ_D_Z = 128
VQ_D_MODEL= 192 VQ_D_MODEL= 192
VQ_NHEAD = 8 VQ_NHEAD = 8

View File

@ -46,8 +46,8 @@ LABEL_SMOOTHING = 0.0
WALL_MASK_RATIO = 0.8 WALL_MASK_RATIO = 0.8
# VQ-VAE 超参 # VQ-VAE 超参
VQ_L = 32 # summary token 数量(即 z 的序列长度) VQ_L = 2 # summary token 数量(即 z 的序列长度)
VQ_K = 1 # codebook 大小 VQ_K = 4 # codebook 大小
VQ_D_Z = 128 # codebook 嵌入维度 VQ_D_Z = 128 # codebook 嵌入维度
VQ_D_MODEL= 192 VQ_D_MODEL= 192
VQ_NHEAD = 8 VQ_NHEAD = 8

View File

@ -21,12 +21,12 @@ EVAL_DATA="ginka-eval.json"
# 阶段 0预训练 # 阶段 0预训练
P0_EPOCHS=50 P0_EPOCHS=50
P0_CHECKPOINT=5 P0_CHECKPOINT=10
P0_FINAL="result/pretrain/pretrain_final.pth" P0_FINAL="result/pretrain/pretrain_final.pth"
# 阶段 1冻结编码器热身 # 阶段 1冻结编码器热身
P1_EPOCHS=30 P1_EPOCHS=30
P1_CHECKPOINT=5 P1_CHECKPOINT=10
P1_FINAL="result/joint/warmup_final.pth" P1_FINAL="result/joint/warmup_final.pth"
# 阶段 2完整联合训练 # 阶段 2完整联合训练
@ -72,7 +72,7 @@ die() {
# ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------
if [[ $START_PHASE -le 0 ]]; then if [[ $START_PHASE -le 0 ]]; then
log "阶段 0 / 3 VQ 编码器预训练 (epochs=${P0_EPOCHS})" log "阶段 0 / 3 VQ 编码器预训练 (epochs=${P0_EPOCHS})"
python -m ginka.train_pretrain \ python3 -u -m ginka.train_pretrain \
--train "$TRAIN_DATA" \ --train "$TRAIN_DATA" \
--validate "$EVAL_DATA" \ --validate "$EVAL_DATA" \
--epochs "$P0_EPOCHS" \ --epochs "$P0_EPOCHS" \
@ -90,7 +90,7 @@ fi
# ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------
if [[ $START_PHASE -le 1 ]]; then if [[ $START_PHASE -le 1 ]]; then
log "阶段 1 / 3 MaskGIT 热身VQ 冻结) (epochs=${P1_EPOCHS})" log "阶段 1 / 3 MaskGIT 热身VQ 冻结) (epochs=${P1_EPOCHS})"
python -m ginka.train_vq \ python3 -u -m ginka.train_vq \
--train "$TRAIN_DATA" \ --train "$TRAIN_DATA" \
--validate "$EVAL_DATA" \ --validate "$EVAL_DATA" \
--resume True \ --resume True \
@ -116,7 +116,7 @@ fi
# ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------
if [[ $START_PHASE -le 2 ]]; then if [[ $START_PHASE -le 2 ]]; then
log "阶段 2 / 3 完整联合训练 (epochs=${P2_EPOCHS})" log "阶段 2 / 3 完整联合训练 (epochs=${P2_EPOCHS})"
python -m ginka.train_vq \ python3 -u -m ginka.train_vq \
--train "$TRAIN_DATA" \ --train "$TRAIN_DATA" \
--validate "$EVAL_DATA" \ --validate "$EVAL_DATA" \
--resume True \ --resume True \