feat: 损失值改为 focal loss

Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
unanmed 2026-04-28 17:23:51 +08:00
parent f22943820c
commit 6460d8c5bc
2 changed files with 57 additions and 14 deletions

View File

@ -50,6 +50,9 @@ VQ_DIM_FF = 512
VQ_BETA = 0.5
VQ_GAMMA = 0.0
# Focal Loss
FOCAL_GAMMA = 2.0 # focal loss 聚焦参数(越大越关注难例/稀有类别)
# 解码头超参(与编码器对称:同等层数和 FFN 宽度)
DH_NHEAD = 8 # Cross-Attention 头数VQ_D_Z=128 可被 8 整除)
DH_DIM_FF = 512 # FFN 隐层维度(与编码器 VQ_DIM_FF 一致)
@ -68,6 +71,24 @@ os.makedirs("result/pretrain", exist_ok=True)
disable_tqdm = not sys.stdout.isatty()
# ---------------------------------------------------------------------------
# Focal Loss
# ---------------------------------------------------------------------------
def focal_loss(
logits: torch.Tensor,
targets: torch.Tensor,
gamma: float = FOCAL_GAMMA,
) -> torch.Tensor:
"""
多分类 Focal Lossmean 归约FL = -(1 - p_t)^gamma * log(p_t)
相比 CE对已被正确分类的高置信度样本施加更小的权重
迫使模型关注难分类的稀有 tile//资源等
"""
ce = F.cross_entropy(logits, targets, reduction='none')
pt = torch.exp(-ce)
return ((1.0 - pt) ** gamma * ce).mean()
# ---------------------------------------------------------------------------
# 简单数据集:仅返回 raw_map无子集划分无掩码
# ---------------------------------------------------------------------------
@ -246,11 +267,9 @@ def train():
# 1. 编码
z_q, _, _, vq_loss, commit_loss, entropy_loss = model_vq(raw_map)
# 2. 解码→全图重建 # 2. 解码→全图重建focal loss 缓解墙壁/空地主导问题)
logits = decode_head(z_q) # [B, H*W, C]
ce_loss = F.cross_entropy( ce_loss = focal_loss(logits.permute(0, 2, 1), raw_map)
logits.permute(0, 2, 1), raw_map # [B, C, H*W] vs [B, H*W]
)
# 3. 总损失(重建 + VQ 正则)
loss = ce_loss + vq_loss

View File

@ -42,7 +42,7 @@ MASK_TOKEN = 15
GENERATE_STEP = 18 # 推理时 MaskGIT 迭代步数
MAP_SIZE = 13 * 13
MAP_H = MAP_W = 13
LABEL_SMOOTHING = 0.0 FOCAL_GAMMA = 2.0 # focal loss 聚焦参数(越大越关注难例/稀有类别)
WALL_MASK_RATIO = 0.8
# VQ-VAE 超参
@ -107,6 +107,36 @@ def parse_arguments():
"适用于预训练权重加载后的热身阶段。")
return parser.parse_args()
# ---------------------------------------------------------------------------
# Focal Loss
# ---------------------------------------------------------------------------
def focal_loss(
logits: torch.Tensor,
targets: torch.Tensor,
gamma: float = FOCAL_GAMMA,
reduction: str = 'none',
) -> torch.Tensor:
"""
多分类 Focal LossFL = -(1 - p_t)^gamma * log(p_t)
相比 CE对已被正确分类的高置信度样本施加更小的权重
迫使模型关注难分类的稀有 tile//资源等
Args:
logits: [B, C, *] 未经 softmax 的原始预测
targets: [B, *] 整数类别标签
gamma: 聚焦参数0 时退化为标准 CE
reduction: 'none' | 'mean' | 'sum'
"""
ce = F.cross_entropy(logits, targets, reduction='none') # [B, *]
pt = torch.exp(-ce) # 正确类的预测概率
fl = (1.0 - pt) ** gamma * ce
if reduction == 'mean':
return fl.mean()
if reduction == 'sum':
return fl.sum()
return fl # 'none'
# ---------------------------------------------------------------------------
# MaskGIT 推理cosine schedule 迭代解码)
# ---------------------------------------------------------------------------
@ -367,10 +397,7 @@ def validate(
logits = model_mg(masked_map, z_q, struct_cond=struct_cond_b)
mask = (masked_map == MASK_TOKEN)
ce_loss = F.cross_entropy( ce_loss = focal_loss(logits.permute(0, 2, 1), target_map)
logits.permute(0, 2, 1), target_map,
reduction='none', label_smoothing=LABEL_SMOOTHING
)
masked_ce = (ce_loss * mask).sum() / (mask.sum() + 1e-6)
val_loss_total += (masked_ce + vq_loss).item()
val_steps += 1
@ -623,12 +650,9 @@ def train():
struct_cond = batch["struct_cond"].to(device) # [B, 4]
logits = model_mg(masked_map, z_q, struct_cond=struct_cond) # [B, 169, C]
# 3. 只对被 mask 的位置计算 CE loss # 3. 只对被 mask 的位置计算 focal loss缓解墙壁/空地主导问题)
mask = (masked_map == MASK_TOKEN) # [B, 169] bool
ce_loss = F.cross_entropy( ce_loss = focal_loss(logits.permute(0, 2, 1), target_map)
logits.permute(0, 2, 1), target_map,
reduction='none', label_smoothing=LABEL_SMOOTHING
)
masked_ce = (ce_loss * mask).sum() / (mask.sum() + 1e-6)
# 4. z 一致性约束(方案 A将 MaskGIT 的 logits 经温度平滑后