mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
feat: 损失值改为 focal loss
Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
parent
f22943820c
commit
6460d8c5bc
@ -50,6 +50,9 @@ VQ_DIM_FF = 512
|
||||
VQ_BETA = 0.5
|
||||
VQ_GAMMA = 0.0
|
||||
|
||||
# Focal Loss
|
||||
FOCAL_GAMMA = 2.0 # focal loss 聚焦参数(越大越关注难例/稀有类别)
|
||||
|
||||
# 解码头超参(与编码器对称:同等层数和 FFN 宽度)
|
||||
DH_NHEAD = 8 # Cross-Attention 头数(VQ_D_Z=128 可被 8 整除)
|
||||
DH_DIM_FF = 512 # FFN 隐层维度(与编码器 VQ_DIM_FF 一致)
|
||||
@ -68,6 +71,24 @@ os.makedirs("result/pretrain", exist_ok=True)
|
||||
|
||||
disable_tqdm = not sys.stdout.isatty()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Focal Loss
|
||||
# ---------------------------------------------------------------------------
|
||||
def focal_loss(
    logits: torch.Tensor,
    targets: torch.Tensor,
    gamma: float = FOCAL_GAMMA,
    reduction: str = 'mean',
) -> torch.Tensor:
    """
    Multi-class focal loss: FL = -(1 - p_t)^gamma * log(p_t).

    Compared to plain cross-entropy, samples that are already classified
    confidently receive a smaller weight, pushing the model to focus on
    hard-to-classify rare tiles (doors / monsters / resources, etc.).

    Args:
        logits: [B, C, *] raw (pre-softmax) predictions.
        targets: [B, *] integer class labels.
        gamma: focusing parameter; 0 degenerates to standard CE.
        reduction: 'mean' | 'sum' | 'none'. Defaults to 'mean', which is
            the original behavior of this helper, so existing callers
            are unaffected.

    Raises:
        ValueError: if `reduction` is not one of the supported modes.
    """
    # Per-element CE gives log(p_t) directly: ce = -log(p_t).
    ce = F.cross_entropy(logits, targets, reduction='none')
    pt = torch.exp(-ce)  # probability the model assigns to the true class
    fl = (1.0 - pt) ** gamma * ce
    if reduction == 'mean':
        return fl.mean()
    if reduction == 'sum':
        return fl.sum()
    if reduction == 'none':
        return fl
    raise ValueError(f"unsupported reduction: {reduction!r}")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 简单数据集:仅返回 raw_map,无子集划分,无掩码
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -246,11 +267,9 @@ def train():
|
||||
# 1. 编码
|
||||
z_q, _, _, vq_loss, commit_loss, entropy_loss = model_vq(raw_map)
|
||||
|
||||
# 2. 解码→全图重建
|
||||
# 2. 解码→全图重建(focal loss 缓解墙壁/空地主导问题)
|
||||
logits = decode_head(z_q) # [B, H*W, C]
|
||||
ce_loss = F.cross_entropy(
|
||||
logits.permute(0, 2, 1), raw_map # [B, C, H*W] vs [B, H*W]
|
||||
)
|
||||
ce_loss = focal_loss(logits.permute(0, 2, 1), raw_map)
|
||||
|
||||
# 3. 总损失(重建 + VQ 正则)
|
||||
loss = ce_loss + vq_loss
|
||||
|
||||
@ -42,7 +42,7 @@ MASK_TOKEN = 15
|
||||
GENERATE_STEP = 18 # 推理时 MaskGIT 迭代步数
|
||||
MAP_SIZE = 13 * 13
|
||||
MAP_H = MAP_W = 13
|
||||
LABEL_SMOOTHING = 0.0
|
||||
FOCAL_GAMMA = 2.0 # focal loss 聚焦参数(越大越关注难例/稀有类别)
|
||||
WALL_MASK_RATIO = 0.8
|
||||
|
||||
# VQ-VAE 超参
|
||||
@ -107,6 +107,36 @@ def parse_arguments():
|
||||
"适用于预训练权重加载后的热身阶段。")
|
||||
return parser.parse_args()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Focal Loss
|
||||
# ---------------------------------------------------------------------------
|
||||
def focal_loss(
    logits: torch.Tensor,
    targets: torch.Tensor,
    gamma: float = FOCAL_GAMMA,
    reduction: str = 'none',
) -> torch.Tensor:
    """
    Multi-class focal loss: FL = -(1 - p_t)^gamma * log(p_t).

    Compared to plain cross-entropy, samples that are already classified
    confidently receive a smaller weight, pushing the model to focus on
    hard-to-classify rare tiles (doors / monsters / resources, etc.).

    Args:
        logits: [B, C, *] raw (pre-softmax) predictions.
        targets: [B, *] integer class labels.
        gamma: focusing parameter; 0 degenerates to standard CE.
        reduction: 'none' | 'mean' | 'sum'. Default 'none' so callers can
            apply their own masking before reducing.

    Raises:
        ValueError: if `reduction` is not one of the supported modes
            (previously an unknown string silently behaved like 'none').
    """
    # Per-element CE gives log(p_t) directly: ce = -log(p_t).
    ce = F.cross_entropy(logits, targets, reduction='none')  # [B, *]
    pt = torch.exp(-ce)  # probability the model assigns to the true class
    fl = (1.0 - pt) ** gamma * ce
    if reduction == 'none':
        return fl
    if reduction == 'mean':
        return fl.mean()
    if reduction == 'sum':
        return fl.sum()
    raise ValueError(f"unsupported reduction: {reduction!r}")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MaskGIT 推理(cosine schedule 迭代解码)
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -367,10 +397,7 @@ def validate(
|
||||
logits = model_mg(masked_map, z_q, struct_cond=struct_cond_b)
|
||||
mask = (masked_map == MASK_TOKEN)
|
||||
|
||||
ce_loss = F.cross_entropy(
|
||||
logits.permute(0, 2, 1), target_map,
|
||||
reduction='none', label_smoothing=LABEL_SMOOTHING
|
||||
)
|
||||
ce_loss = focal_loss(logits.permute(0, 2, 1), target_map)
|
||||
masked_ce = (ce_loss * mask).sum() / (mask.sum() + 1e-6)
|
||||
val_loss_total += (masked_ce + vq_loss).item()
|
||||
val_steps += 1
|
||||
@ -623,12 +650,9 @@ def train():
|
||||
struct_cond = batch["struct_cond"].to(device) # [B, 4]
|
||||
logits = model_mg(masked_map, z_q, struct_cond=struct_cond) # [B, 169, C]
|
||||
|
||||
# 3. 只对被 mask 的位置计算 CE loss
|
||||
# 3. 只对被 mask 的位置计算 focal loss(缓解墙壁/空地主导问题)
|
||||
mask = (masked_map == MASK_TOKEN) # [B, 169] bool
|
||||
ce_loss = F.cross_entropy(
|
||||
logits.permute(0, 2, 1), target_map,
|
||||
reduction='none', label_smoothing=LABEL_SMOOTHING
|
||||
)
|
||||
ce_loss = focal_loss(logits.permute(0, 2, 1), target_map)
|
||||
masked_ce = (ce_loss * mask).sum() / (mask.sum() + 1e-6)
|
||||
|
||||
# 4. z 一致性约束(方案 A):将 MaskGIT 的 logits 经温度平滑后
|
||||
|
||||
Loading…
Reference in New Issue
Block a user