From 294e214431800c3d33d49656a53fc7bbd43640c8 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Tue, 28 Apr 2026 16:46:02 +0800 Subject: [PATCH] =?UTF-8?q?chore:=20=E8=B0=83=E6=95=B4=E8=B6=85=E5=8F=82?= =?UTF-8?q?=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/train_pretrain.py | 4 ++-- ginka/train_vq.py | 4 ++-- train_full.sh | 10 +++++----- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/ginka/train_pretrain.py b/ginka/train_pretrain.py index a4c9dd4..e8c157a 100644 --- a/ginka/train_pretrain.py +++ b/ginka/train_pretrain.py @@ -40,8 +40,8 @@ MAP_SIZE = 13 * 13 MAP_H = MAP_W = 13 # VQ-VAE 超参(保持与 train_vq.py 一致) -VQ_L = 32 -VQ_K = 1 +VQ_L = 2 +VQ_K = 4 VQ_D_Z = 128 VQ_D_MODEL= 192 VQ_NHEAD = 8 diff --git a/ginka/train_vq.py b/ginka/train_vq.py index a891582..e17328c 100644 --- a/ginka/train_vq.py +++ b/ginka/train_vq.py @@ -46,8 +46,8 @@ LABEL_SMOOTHING = 0.0 WALL_MASK_RATIO = 0.8 # VQ-VAE 超参 -VQ_L = 32 # summary token 数量(即 z 的序列长度) -VQ_K = 1 # codebook 大小 +VQ_L = 2 # summary token 数量(即 z 的序列长度) +VQ_K = 4 # codebook 大小 VQ_D_Z = 128 # codebook 嵌入维度 VQ_D_MODEL= 192 VQ_NHEAD = 8 diff --git a/train_full.sh b/train_full.sh index 4a8409f..1e27d7f 100644 --- a/train_full.sh +++ b/train_full.sh @@ -21,12 +21,12 @@ EVAL_DATA="ginka-eval.json" # 阶段 0:预训练 P0_EPOCHS=50 -P0_CHECKPOINT=5 +P0_CHECKPOINT=10 P0_FINAL="result/pretrain/pretrain_final.pth" # 阶段 1:冻结编码器热身 P1_EPOCHS=30 -P1_CHECKPOINT=5 +P1_CHECKPOINT=10 P1_FINAL="result/joint/warmup_final.pth" # 阶段 2:完整联合训练 @@ -72,7 +72,7 @@ die() { # ------------------------------------------------------------------------------ if [[ $START_PHASE -le 0 ]]; then log "阶段 0 / 3 VQ 编码器预训练 (epochs=${P0_EPOCHS})" - python -m ginka.train_pretrain \ + python3 -u -m ginka.train_pretrain \ --train "$TRAIN_DATA" \ --validate "$EVAL_DATA" \ --epochs "$P0_EPOCHS" \ @@ -90,7 +90,7 @@ fi # ------------------------------------------------------------------------------ if [[ $START_PHASE -le 1 ]]; then log "阶段 1 / 3 MaskGIT 热身(VQ 冻结) (epochs=${P1_EPOCHS})" - python -m ginka.train_vq \ + python3 -u -m ginka.train_vq \ --train "$TRAIN_DATA" \ --validate "$EVAL_DATA" \ --resume True \ @@ -116,7 +116,7 @@ fi # ------------------------------------------------------------------------------ if [[ $START_PHASE -le 2 ]]; then log "阶段 2 / 3 完整联合训练 (epochs=${P2_EPOCHS})" - python -m ginka.train_vq \ + python3 -u -m ginka.train_vq \ --train "$TRAIN_DATA" \ --validate "$EVAL_DATA" \ --resume True \