From ea57bbde3a733de64bdb1d1dc810bfb6f080aa41 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Wed, 6 May 2026 20:44:00 +0800 Subject: [PATCH] =?UTF-8?q?chore:=20=E4=BF=AE=E6=94=B9=E8=B6=85=E5=8F=82?= =?UTF-8?q?=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/train_pretrain_split.py | 12 ++++++------ ginka/train_vq.py | 8 ++++---- train_full.sh | 8 ++++---- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/ginka/train_pretrain_split.py b/ginka/train_pretrain_split.py index c7e055c..79fd350 100644 --- a/ginka/train_pretrain_split.py +++ b/ginka/train_pretrain_split.py @@ -42,25 +42,25 @@ FOCAL_GAMMA = 2.0 # 通道 1:空间骨架(floor+wall) CH1_KEEP = {0, 1} # 编码器输入保留的 tile CH1_LOSS = {0, 1} # 损失计算范围(仅 wall) -CH1_D_MODEL = 128 +CH1_D_MODEL = 64 CH1_NHEAD = 8 # 通道 2:关卡门控 CH2_KEEP = {0, 1, 2, 9, 10} -CH2_LOSS = {2, 9, 10} -CH2_D_MODEL = 128 +CH2_LOSS = {0, 1, 2, 9, 10} +CH2_D_MODEL = 64 CH2_NHEAD = 8 # 通道 3:收集资源 CH3_KEEP = None # 完整地图,无需切片 -CH3_LOSS = {3, 4, 5, 6, 7, 8} -CH3_D_MODEL = 128 +CH3_LOSS = {0, 1, 2, 3, 9, 10} +CH3_D_MODEL = 64 CH3_NHEAD = 8 # 三路共用的 VQ 超参 VQ_L = 2 VQ_K = 8 -VQ_D_Z = 128 +VQ_D_Z = 64 VQ_LAYERS = 3 VQ_DIM_FF = 512 VQ_BETA = 0.5 # commit loss 权重 diff --git a/ginka/train_vq.py b/ginka/train_vq.py index 3d16f39..e6c0e0f 100644 --- a/ginka/train_vq.py +++ b/ginka/train_vq.py @@ -48,14 +48,14 @@ WALL_MASK_RATIO = 0.8 # VQ-VAE 公共超参(三路编码器共用,方案 B 三通道分拆) VQ_L = 2 # 每路码字序列长度(三路合计 L1+L2+L3 = 6) VQ_K = 8 # codebook 大小 -VQ_D_Z = 128 # codebook 嵌入维度(三路保持一致,便于拼接) +VQ_D_Z = 64 # codebook 嵌入维度(三路保持一致,便于拼接) VQ_BETA = 0.5 # commit loss 权重 VQ_GAMMA = 0.0 # entropy loss 权重 # 各通道编码器配置 -CH1_D_MODEL = 128; CH1_NHEAD = 8 # 通道 1:空间骨架(floor+wall) -CH2_D_MODEL = 128; CH2_NHEAD = 8 # 通道 2:关卡门控 -CH3_D_MODEL = 128; CH3_NHEAD = 8 # 通道 3:收集资源 +CH1_D_MODEL = 64; CH1_NHEAD = 8 # 通道 1:空间骨架(floor+wall) +CH2_D_MODEL = 64; CH2_NHEAD = 8 # 通道 2:关卡门控 +CH3_D_MODEL = 64; CH3_NHEAD = 8 # 通道 3:收集资源 VQ_LAYERS = 3 VQ_DIM_FF = 512 diff --git a/train_full.sh b/train_full.sh index a375ee1..53b9639 100644 --- a/train_full.sh +++ b/train_full.sh @@ -24,13 +24,13 @@ TRAIN_DATA="ginka-dataset.json" EVAL_DATA="ginka-eval.json" # 阶段 0:三通道分拆预训练 -P0_EPOCHS=30 -P0_CHECKPOINT=10 +P0_EPOCHS=10 +P0_CHECKPOINT=5 P0_FINAL="result/pretrain_split/split_final.pth" # 阶段 1:冻结编码器热身 -P1_EPOCHS=50 -P1_CHECKPOINT=10 +P1_EPOCHS=10 +P1_CHECKPOINT=5 P1_FINAL="result/joint/warmup_final.pth" # 阶段 2:完整联合训练