From 850c038be3dffac24ee6db90bf22d469f54efd1d Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Tue, 5 May 2026 21:49:44 +0800 Subject: [PATCH] =?UTF-8?q?chore:=20=E4=BF=AE=E6=94=B9=E8=B6=85=E5=8F=82?= =?UTF-8?q?=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/train_pretrain_split.py | 36 ++++++++++++++++++++--------------- ginka/train_vq.py | 18 +++++++++--------- train_full.sh | 2 +- 3 files changed, 31 insertions(+), 25 deletions(-) diff --git a/ginka/train_pretrain_split.py b/ginka/train_pretrain_split.py index 9cad109..5e68c5c 100644 --- a/ginka/train_pretrain_split.py +++ b/ginka/train_pretrain_split.py @@ -41,35 +41,35 @@ FOCAL_GAMMA = 2.0 # 通道 1:空间骨架(floor+wall) CH1_KEEP = {0, 1} # 编码器输入保留的 tile -CH1_LOSS = {1} # 损失计算范围(仅 wall) +CH1_LOSS = {0, 1} # 损失计算范围(仅 wall) CH1_D_MODEL = 128 -CH1_NHEAD = 4 +CH1_NHEAD = 8 # 通道 2:关卡门控 CH2_KEEP = {0, 1, 2, 9, 10} CH2_LOSS = {2, 9, 10} -CH2_D_MODEL = 64 -CH2_NHEAD = 4 +CH2_D_MODEL = 128 +CH2_NHEAD = 8 # 通道 3:收集资源 CH3_KEEP = None # 完整地图,无需切片 CH3_LOSS = {3, 4, 5, 6, 7, 8} -CH3_D_MODEL = 64 -CH3_NHEAD = 4 +CH3_D_MODEL = 128 +CH3_NHEAD = 8 # 三路共用的 VQ 超参 VQ_L = 2 -VQ_K = 16 -VQ_D_Z = 64 -VQ_LAYERS = 2 -VQ_DIM_FF = 256 -VQ_BETA = 0.25 # commit loss 权重 -VQ_GAMMA = 0.1 # entropy loss 权重 +VQ_K = 8 +VQ_D_Z = 128 +VQ_LAYERS = 3 +VQ_DIM_FF = 512 +VQ_BETA = 0.5 # commit loss 权重 +VQ_GAMMA = 0.0 # entropy loss 权重 # 解码头超参(三路共用相同规格) -DH_NHEAD = 4 -DH_DIM_FF = 256 -DH_LAYERS = 2 +DH_NHEAD = 8 +DH_DIM_FF = 512 +DH_LAYERS = 3 # --------------------------------------------------------------------------- # 设备 @@ -258,6 +258,12 @@ def train(): sum(p.numel() for p in enc3.parameters()) ) print(f"编码器总参数量(三路): {total_params:,} ({total_params / 1e6:.3f}M)") + total_params = ( + sum(p.numel() for p in head1.parameters()) + + sum(p.numel() for p in head2.parameters()) + + sum(p.numel() for p in head3.parameters()) + ) + print(f"解码器总参数量(三路): {total_params:,} ({total_params / 1e6:.3f}M)") # ---- 训练循环 ---- for epoch in range(start_epoch, args.epochs): diff --git a/ginka/train_vq.py b/ginka/train_vq.py index 05a56f4..29bf6e4 100644 --- a/ginka/train_vq.py +++ b/ginka/train_vq.py @@ -47,17 +47,17 @@ WALL_MASK_RATIO = 0.8 # VQ-VAE 公共超参(三路编码器共用,方案 B 三通道分拆) VQ_L = 2 # 每路码字序列长度(三路合计 L1+L2+L3 = 6) -VQ_K = 16 # codebook 大小 -VQ_D_Z = 64 # codebook 嵌入维度(三路保持一致,便于拼接) -VQ_BETA = 0.25 # commit loss 权重 -VQ_GAMMA = 0.1 # entropy loss 权重 +VQ_K = 8 # codebook 大小 +VQ_D_Z = 128 # codebook 嵌入维度(三路保持一致,便于拼接) +VQ_BETA = 0.5 # commit loss 权重 +VQ_GAMMA = 0.0 # entropy loss 权重 # 各通道编码器配置 -CH1_D_MODEL = 128; CH1_NHEAD = 4 # 通道 1:空间骨架(floor+wall) -CH2_D_MODEL = 64; CH2_NHEAD = 4 # 通道 2:关卡门控 -CH3_D_MODEL = 64; CH3_NHEAD = 4 # 通道 3:收集资源 -VQ_LAYERS = 2 -VQ_DIM_FF = 256 +CH1_D_MODEL = 128; CH1_NHEAD = 8 # 通道 1:空间骨架(floor+wall) +CH2_D_MODEL = 128; CH2_NHEAD = 8 # 通道 2:关卡门控 +CH3_D_MODEL = 128; CH3_NHEAD = 8 # 通道 3:收集资源 +VQ_LAYERS = 3 +VQ_DIM_FF = 512 # 通道专属损失计算范围(用于监控验证召回率) CH1_LOSS = {1} diff --git a/train_full.sh b/train_full.sh index 5de90a2..a375ee1 100644 --- a/train_full.sh +++ b/train_full.sh @@ -24,7 +24,7 @@ TRAIN_DATA="ginka-dataset.json" EVAL_DATA="ginka-eval.json" # 阶段 0:三通道分拆预训练 -P0_EPOCHS=100 +P0_EPOCHS=30 P0_CHECKPOINT=10 P0_FINAL="result/pretrain_split/split_final.pth"