From 7535ecc9fe0d984924c13da71ae56785114abde1 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Wed, 6 May 2026 22:47:08 +0800 Subject: [PATCH] =?UTF-8?q?chore:=20=E8=B0=83=E6=95=B4=E6=8D=9F=E5=A4=B1?= =?UTF-8?q?=E5=80=BC=E8=AE=A1=E7=AE=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/train_pretrain_split.py | 6 +++--- ginka/utils.py | 4 ++-- train_full.sh | 10 +++++----- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/ginka/train_pretrain_split.py b/ginka/train_pretrain_split.py index ca6033a..9841306 100644 --- a/ginka/train_pretrain_split.py +++ b/ginka/train_pretrain_split.py @@ -37,7 +37,7 @@ from .utils import masked_focal BATCH_SIZE = 64 NUM_CLASSES = 7 MAP_SIZE = 13 * 13 -FOCAL_GAMMA = 2.0 +FOCAL_GAMMA = 1.0 # 通道 1:空间骨架(floor+wall) CH1_KEEP = {0, 1} # 编码器输入保留的 tile @@ -47,13 +47,13 @@ CH1_NHEAD = 8 # 通道 2:关卡门控 CH2_KEEP = {0, 1, 2, 4, 5} -CH2_LOSS = {0, 1, 2, 4, 5} +CH2_LOSS = {2, 4, 5} CH2_D_MODEL = 64 CH2_NHEAD = 8 # 通道 3:收集资源 CH3_KEEP = None # 完整地图,无需切片 -CH3_LOSS = {0, 1, 2, 3, 4, 5} +CH3_LOSS = {3} CH3_D_MODEL = 64 CH3_NHEAD = 8 diff --git a/ginka/utils.py b/ginka/utils.py index 6231c84..62d36f0 100644 --- a/ginka/utils.py +++ b/ginka/utils.py @@ -75,9 +75,9 @@ def masked_focal( # count[0],导致 weight[0] 趋近于 0、非专属位置损失被消除的问题 class_weight = None if balance: - flat = target.view(-1) # [B*S] 原始标签 + flat = corrected.view(-1) # [B*S] 原始标签 counts = torch.bincount(flat, minlength=C).float() # [C] - class_weight = flat.numel() / (counts.clamp(min=1.0) * C) + class_weight = torch.sqrt(flat.numel() / (counts.clamp(min=1.0) * C)) class_weight[counts == 0] = 0.0 # 未出现类别不参与 ce = F.cross_entropy( diff --git a/train_full.sh b/train_full.sh index 53b9639..f1c894b 100644 --- a/train_full.sh +++ b/train_full.sh @@ -24,17 +24,17 @@ TRAIN_DATA="ginka-dataset.json" EVAL_DATA="ginka-eval.json" # 阶段 0:三通道分拆预训练 -P0_EPOCHS=10 -P0_CHECKPOINT=5 +P0_EPOCHS=20 +P0_CHECKPOINT=10 P0_FINAL="result/pretrain_split/split_final.pth" # 阶段 1:冻结编码器热身 -P1_EPOCHS=10 -P1_CHECKPOINT=5 +P1_EPOCHS=30 +P1_CHECKPOINT=10 P1_FINAL="result/joint/warmup_final.pth" # 阶段 2:完整联合训练 -P2_EPOCHS=400 +P2_EPOCHS=470 P2_CHECKPOINT=20 # 从哪个阶段开始(0 = 从头);命令行 --skip N 可覆盖此值