mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-21 02:11:13 +08:00
chore: 修改超参数
This commit is contained in:
parent
3e273c5a9d
commit
850c038be3
@@ -41,35 +41,35 @@ FOCAL_GAMMA = 2.0
|
|||||||
|
|
||||||
# 通道 1:空间骨架(floor+wall)
|
# 通道 1:空间骨架(floor+wall)
|
||||||
CH1_KEEP = {0, 1} # 编码器输入保留的 tile
|
CH1_KEEP = {0, 1} # 编码器输入保留的 tile
|
||||||
CH1_LOSS = {1} # 损失计算范围(仅 wall)
|
CH1_LOSS = {0, 1}   # 损失计算范围(floor+wall)
|
||||||
CH1_D_MODEL = 128
|
CH1_D_MODEL = 128
|
||||||
CH1_NHEAD = 4
|
CH1_NHEAD = 8
|
||||||
|
|
||||||
# 通道 2:关卡门控
|
# 通道 2:关卡门控
|
||||||
CH2_KEEP = {0, 1, 2, 9, 10}
|
CH2_KEEP = {0, 1, 2, 9, 10}
|
||||||
CH2_LOSS = {2, 9, 10}
|
CH2_LOSS = {2, 9, 10}
|
||||||
CH2_D_MODEL = 64
|
CH2_D_MODEL = 128
|
||||||
CH2_NHEAD = 4
|
CH2_NHEAD = 8
|
||||||
|
|
||||||
# 通道 3:收集资源
|
# 通道 3:收集资源
|
||||||
CH3_KEEP = None # 完整地图,无需切片
|
CH3_KEEP = None # 完整地图,无需切片
|
||||||
CH3_LOSS = {3, 4, 5, 6, 7, 8}
|
CH3_LOSS = {3, 4, 5, 6, 7, 8}
|
||||||
CH3_D_MODEL = 64
|
CH3_D_MODEL = 128
|
||||||
CH3_NHEAD = 4
|
CH3_NHEAD = 8
|
||||||
|
|
||||||
# 三路共用的 VQ 超参
|
# 三路共用的 VQ 超参
|
||||||
VQ_L = 2
|
VQ_L = 2
|
||||||
VQ_K = 16
|
VQ_K = 8
|
||||||
VQ_D_Z = 64
|
VQ_D_Z = 128
|
||||||
VQ_LAYERS = 2
|
VQ_LAYERS = 3
|
||||||
VQ_DIM_FF = 256
|
VQ_DIM_FF = 512
|
||||||
VQ_BETA = 0.25 # commit loss 权重
|
VQ_BETA = 0.5 # commit loss 权重
|
||||||
VQ_GAMMA = 0.1 # entropy loss 权重
|
VQ_GAMMA = 0.0 # entropy loss 权重
|
||||||
|
|
||||||
# 解码头超参(三路共用相同规格)
|
# 解码头超参(三路共用相同规格)
|
||||||
DH_NHEAD = 4
|
DH_NHEAD = 8
|
||||||
DH_DIM_FF = 256
|
DH_DIM_FF = 512
|
||||||
DH_LAYERS = 2
|
DH_LAYERS = 3
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# 设备
|
# 设备
|
||||||
@@ -258,6 +258,12 @@ def train():
|
|||||||
sum(p.numel() for p in enc3.parameters())
|
sum(p.numel() for p in enc3.parameters())
|
||||||
)
|
)
|
||||||
print(f"编码器总参数量(三路): {total_params:,} ({total_params / 1e6:.3f}M)")
|
print(f"编码器总参数量(三路): {total_params:,} ({total_params / 1e6:.3f}M)")
|
||||||
|
total_params = (
|
||||||
|
sum(p.numel() for p in head1.parameters()) +
|
||||||
|
sum(p.numel() for p in head2.parameters()) +
|
||||||
|
sum(p.numel() for p in head3.parameters())
|
||||||
|
)
|
||||||
|
print(f"解码器总参数量(三路): {total_params:,} ({total_params / 1e6:.3f}M)")
|
||||||
|
|
||||||
# ---- 训练循环 ----
|
# ---- 训练循环 ----
|
||||||
for epoch in range(start_epoch, args.epochs):
|
for epoch in range(start_epoch, args.epochs):
|
||||||
|
|||||||
@@ -47,17 +47,17 @@ WALL_MASK_RATIO = 0.8
|
|||||||
|
|
||||||
# VQ-VAE 公共超参(三路编码器共用,方案 B 三通道分拆)
|
# VQ-VAE 公共超参(三路编码器共用,方案 B 三通道分拆)
|
||||||
VQ_L = 2 # 每路码字序列长度(三路合计 L1+L2+L3 = 6)
|
VQ_L = 2 # 每路码字序列长度(三路合计 L1+L2+L3 = 6)
|
||||||
VQ_K = 16 # codebook 大小
|
VQ_K = 8 # codebook 大小
|
||||||
VQ_D_Z = 64 # codebook 嵌入维度(三路保持一致,便于拼接)
|
VQ_D_Z = 128 # codebook 嵌入维度(三路保持一致,便于拼接)
|
||||||
VQ_BETA = 0.25 # commit loss 权重
|
VQ_BETA = 0.5 # commit loss 权重
|
||||||
VQ_GAMMA = 0.1 # entropy loss 权重
|
VQ_GAMMA = 0.0 # entropy loss 权重
|
||||||
|
|
||||||
# 各通道编码器配置
|
# 各通道编码器配置
|
||||||
CH1_D_MODEL = 128; CH1_NHEAD = 4 # 通道 1:空间骨架(floor+wall)
|
CH1_D_MODEL = 128; CH1_NHEAD = 8 # 通道 1:空间骨架(floor+wall)
|
||||||
CH2_D_MODEL = 64; CH2_NHEAD = 4 # 通道 2:关卡门控
|
CH2_D_MODEL = 128; CH2_NHEAD = 8 # 通道 2:关卡门控
|
||||||
CH3_D_MODEL = 64; CH3_NHEAD = 4 # 通道 3:收集资源
|
CH3_D_MODEL = 128; CH3_NHEAD = 8 # 通道 3:收集资源
|
||||||
VQ_LAYERS = 2
|
VQ_LAYERS = 3
|
||||||
VQ_DIM_FF = 256
|
VQ_DIM_FF = 512
|
||||||
|
|
||||||
# 通道专属损失计算范围(用于监控验证召回率)
|
# 通道专属损失计算范围(用于监控验证召回率)
|
||||||
CH1_LOSS = {1}
|
CH1_LOSS = {1}
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ TRAIN_DATA="ginka-dataset.json"
|
|||||||
EVAL_DATA="ginka-eval.json"
|
EVAL_DATA="ginka-eval.json"
|
||||||
|
|
||||||
# 阶段 0:三通道分拆预训练
|
# 阶段 0:三通道分拆预训练
|
||||||
P0_EPOCHS=100
|
P0_EPOCHS=30
|
||||||
P0_CHECKPOINT=10
|
P0_CHECKPOINT=10
|
||||||
P0_FINAL="result/pretrain_split/split_final.pth"
|
P0_FINAL="result/pretrain_split/split_final.pth"
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user