mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
chore: 调整损失值计算
This commit is contained in:
parent
f0025df1ec
commit
7535ecc9fe
@ -37,7 +37,7 @@ from .utils import masked_focal
|
|||||||
BATCH_SIZE = 64
|
BATCH_SIZE = 64
|
||||||
NUM_CLASSES = 7
|
NUM_CLASSES = 7
|
||||||
MAP_SIZE = 13 * 13
|
MAP_SIZE = 13 * 13
|
||||||
FOCAL_GAMMA = 2.0
|
FOCAL_GAMMA = 1.0
|
||||||
|
|
||||||
# 通道 1:空间骨架(floor+wall)
|
# 通道 1:空间骨架(floor+wall)
|
||||||
CH1_KEEP = {0, 1} # 编码器输入保留的 tile
|
CH1_KEEP = {0, 1} # 编码器输入保留的 tile
|
||||||
@ -47,13 +47,13 @@ CH1_NHEAD = 8
|
|||||||
|
|
||||||
# 通道 2:关卡门控
|
# 通道 2:关卡门控
|
||||||
CH2_KEEP = {0, 1, 2, 4, 5}
|
CH2_KEEP = {0, 1, 2, 4, 5}
|
||||||
CH2_LOSS = {0, 1, 2, 4, 5}
|
CH2_LOSS = {2, 4, 5}
|
||||||
CH2_D_MODEL = 64
|
CH2_D_MODEL = 64
|
||||||
CH2_NHEAD = 8
|
CH2_NHEAD = 8
|
||||||
|
|
||||||
# 通道 3:收集资源
|
# 通道 3:收集资源
|
||||||
CH3_KEEP = None # 完整地图,无需切片
|
CH3_KEEP = None # 完整地图,无需切片
|
||||||
CH3_LOSS = {0, 1, 2, 3, 4, 5}
|
CH3_LOSS = {3}
|
||||||
CH3_D_MODEL = 64
|
CH3_D_MODEL = 64
|
||||||
CH3_NHEAD = 8
|
CH3_NHEAD = 8
|
||||||
|
|
||||||
|
|||||||
@ -75,9 +75,9 @@ def masked_focal(
|
|||||||
# count[0],导致 weight[0] 趋近于 0、非专属位置损失被消除的问题
|
# count[0],导致 weight[0] 趋近于 0、非专属位置损失被消除的问题
|
||||||
class_weight = None
|
class_weight = None
|
||||||
if balance:
|
if balance:
|
||||||
flat = target.view(-1) # [B*S] 原始标签
|
flat = corrected.view(-1) # [B*S] 原始标签
|
||||||
counts = torch.bincount(flat, minlength=C).float() # [C]
|
counts = torch.bincount(flat, minlength=C).float() # [C]
|
||||||
class_weight = flat.numel() / (counts.clamp(min=1.0) * C)
|
class_weight = torch.sqrt(flat.numel() / (counts.clamp(min=1.0) * C))
|
||||||
class_weight[counts == 0] = 0.0 # 未出现类别不参与
|
class_weight[counts == 0] = 0.0 # 未出现类别不参与
|
||||||
|
|
||||||
ce = F.cross_entropy(
|
ce = F.cross_entropy(
|
||||||
|
|||||||
@ -24,17 +24,17 @@ TRAIN_DATA="ginka-dataset.json"
|
|||||||
EVAL_DATA="ginka-eval.json"
|
EVAL_DATA="ginka-eval.json"
|
||||||
|
|
||||||
# 阶段 0:三通道分拆预训练
|
# 阶段 0:三通道分拆预训练
|
||||||
P0_EPOCHS=10
|
P0_EPOCHS=20
|
||||||
P0_CHECKPOINT=5
|
P0_CHECKPOINT=10
|
||||||
P0_FINAL="result/pretrain_split/split_final.pth"
|
P0_FINAL="result/pretrain_split/split_final.pth"
|
||||||
|
|
||||||
# 阶段 1:冻结编码器热身
|
# 阶段 1:冻结编码器热身
|
||||||
P1_EPOCHS=10
|
P1_EPOCHS=30
|
||||||
P1_CHECKPOINT=5
|
P1_CHECKPOINT=10
|
||||||
P1_FINAL="result/joint/warmup_final.pth"
|
P1_FINAL="result/joint/warmup_final.pth"
|
||||||
|
|
||||||
# 阶段 2:完整联合训练
|
# 阶段 2:完整联合训练
|
||||||
P2_EPOCHS=400
|
P2_EPOCHS=470
|
||||||
P2_CHECKPOINT=20
|
P2_CHECKPOINT=20
|
||||||
|
|
||||||
# 从哪个阶段开始(0 = 从头);命令行 --skip N 可覆盖此值
|
# 从哪个阶段开始(0 = 从头);命令行 --skip N 可覆盖此值
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user