# Focal Loss
FOCAL_GAMMA = 2.0  # focusing parameter: larger values up-weight hard / rare classes


def focal_loss(
    logits: torch.Tensor,
    targets: torch.Tensor,
    gamma: float = FOCAL_GAMMA,
    reduction: str = 'mean',
) -> torch.Tensor:
    """Multi-class focal loss: FL = -(1 - p_t)^gamma * log(p_t).

    Compared with plain cross-entropy, samples the model already classifies
    confidently are down-weighted, pushing training toward hard / rare tiles
    (doors, monsters, resources, ...) instead of the dominant wall/floor
    classes.

    Args:
        logits: raw (pre-softmax) predictions, laid out so that
            ``F.cross_entropy`` accepts them (class dim second, e.g. [B, C, *]).
        targets: integer class labels, shape matching ``logits`` without the
            class dimension (e.g. [B, *]).
        gamma: focusing parameter; ``0`` degenerates to standard CE.
        reduction: ``'mean'`` (default — matches the original behavior of this
            helper) | ``'sum'`` | ``'none'``. Added for consistency with the
            twin helper in train_vq.py.

    Returns:
        A scalar tensor for ``'mean'`` / ``'sum'``; the per-element loss
        tensor for ``'none'``.

    Raises:
        ValueError: if ``reduction`` is not one of the supported modes.
    """
    # Per-element CE gives -log(p_t) directly, so p_t = exp(-ce).
    ce = F.cross_entropy(logits, targets, reduction='none')
    pt = torch.exp(-ce)  # probability assigned to the true class
    fl = (1.0 - pt) ** gamma * ce
    if reduction == 'mean':
        return fl.mean()
    if reduction == 'sum':
        return fl.sum()
    if reduction == 'none':
        return fl
    raise ValueError(f"unsupported reduction: {reduction!r}")
FOCAL_GAMMA = 2.0  # focusing parameter: larger values up-weight hard / rare classes


# ---------------------------------------------------------------------------
# Focal Loss
# ---------------------------------------------------------------------------
def focal_loss(
    logits: torch.Tensor,
    targets: torch.Tensor,
    gamma: float = FOCAL_GAMMA,
    reduction: str = 'none',
) -> torch.Tensor:
    """Multi-class focal loss: FL = -(1 - p_t)^gamma * log(p_t).

    Compared with plain cross-entropy, samples the model already classifies
    confidently are down-weighted, pushing training toward hard / rare tiles
    (doors, monsters, resources, ...).

    Args:
        logits: raw (pre-softmax) predictions, shape [B, C, *].
        targets: integer class labels, shape [B, *].
        gamma: focusing parameter; ``0`` degenerates to standard CE.
        reduction: ``'none'`` (default — callers mask per-position losses
            themselves) | ``'mean'`` | ``'sum'``.

    Returns:
        The per-element loss tensor for ``'none'``; a scalar otherwise.

    Raises:
        ValueError: if ``reduction`` is not one of the supported modes
        (previously an unknown string silently behaved like ``'none'``).
    """
    # Per-element CE gives -log(p_t) directly, so p_t = exp(-ce).
    ce = F.cross_entropy(logits, targets, reduction='none')  # [B, *]
    pt = torch.exp(-ce)  # probability assigned to the true class
    fl = (1.0 - pt) ** gamma * ce
    if reduction == 'none':
        return fl
    if reduction == 'mean':
        return fl.mean()
    if reduction == 'sum':
        return fl.sum()
    raise ValueError(f"unsupported reduction: {reduction!r}")
batch["struct_cond"].to(device) # [B, 4] logits = model_mg(masked_map, z_q, struct_cond=struct_cond) # [B, 169, C] - # 3. 只对被 mask 的位置计算 CE loss + # 3. 只对被 mask 的位置计算 focal loss(缓解墙壁/空地主导问题) mask = (masked_map == MASK_TOKEN) # [B, 169] bool - ce_loss = F.cross_entropy( - logits.permute(0, 2, 1), target_map, - reduction='none', label_smoothing=LABEL_SMOOTHING - ) + ce_loss = focal_loss(logits.permute(0, 2, 1), target_map) masked_ce = (ce_loss * mask).sum() / (mask.sum() + 1e-6) # 4. z 一致性约束(方案 A):将 MaskGIT 的 logits 经温度平滑后