diff --git a/ginka/train_pretrain_split.py b/ginka/train_pretrain_split.py index ca6033a..9841306 100644 --- a/ginka/train_pretrain_split.py +++ b/ginka/train_pretrain_split.py @@ -37,7 +37,7 @@ from .utils import masked_focal BATCH_SIZE = 64 NUM_CLASSES = 7 MAP_SIZE = 13 * 13 -FOCAL_GAMMA = 2.0 +FOCAL_GAMMA = 1.0 # 通道 1:空间骨架(floor+wall) CH1_KEEP = {0, 1} # 编码器输入保留的 tile @@ -47,13 +47,13 @@ CH1_NHEAD = 8 # 通道 2:关卡门控 CH2_KEEP = {0, 1, 2, 4, 5} -CH2_LOSS = {0, 1, 2, 4, 5} +CH2_LOSS = {2, 4, 5} CH2_D_MODEL = 64 CH2_NHEAD = 8 # 通道 3:收集资源 CH3_KEEP = None # 完整地图,无需切片 -CH3_LOSS = {0, 1, 2, 3, 4, 5} +CH3_LOSS = {3} CH3_D_MODEL = 64 CH3_NHEAD = 8 diff --git a/ginka/utils.py b/ginka/utils.py index 6231c84..62d36f0 100644 --- a/ginka/utils.py +++ b/ginka/utils.py @@ -75,9 +75,9 @@ def masked_focal( # count[0],导致 weight[0] 趋近于 0、非专属位置损失被消除的问题 class_weight = None if balance: - flat = target.view(-1) # [B*S] 原始标签 + flat = corrected.view(-1) # [B*S] labels from corrected (no longer the original target) counts = torch.bincount(flat, minlength=C).float() # [C] - class_weight = flat.numel() / (counts.clamp(min=1.0) * C) + class_weight = torch.sqrt(flat.numel() / (counts.clamp(min=1.0) * C)) class_weight[counts == 0] = 0.0 # 未出现类别不参与 ce = F.cross_entropy( diff --git a/train_full.sh b/train_full.sh index 53b9639..f1c894b 100644 --- a/train_full.sh +++ b/train_full.sh @@ -24,17 +24,17 @@ TRAIN_DATA="ginka-dataset.json" EVAL_DATA="ginka-eval.json" # 阶段 0:三通道分拆预训练 -P0_EPOCHS=10 -P0_CHECKPOINT=5 +P0_EPOCHS=20 +P0_CHECKPOINT=10 P0_FINAL="result/pretrain_split/split_final.pth" # 阶段 1:冻结编码器热身 -P1_EPOCHS=10 -P1_CHECKPOINT=5 +P1_EPOCHS=30 +P1_CHECKPOINT=10 P1_FINAL="result/joint/warmup_final.pth" # 阶段 2:完整联合训练 -P2_EPOCHS=400 +P2_EPOCHS=470 P2_CHECKPOINT=20 # 从哪个阶段开始(0 = 从头);命令行 --skip N 可覆盖此值