fix: 验证

This commit is contained in:
unanmed 2026-03-11 18:26:15 +08:00
parent 19fadf50bd
commit f35d0b03af

View File

@ -37,7 +37,7 @@ from .maskGIT.mask import MapMask
# 6. 道具热力图, 7. 入口热力图, 8. 门热力图
BATCH_SIZE = 128
VAL_BATCH_DIVIDER = 128
VAL_BATCH_DIVIDER = 64
NUM_CLASSES = 16
MASK_TOKEN = 15
GENERATE_STEP = 8
@ -182,7 +182,7 @@ def train():
logits = model(masked_input, cond, heatmap)
loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none', label_smoothing=0.1)
loss = (loss * mask.view(-1)).sum() / (mask.sum() + 1e-6)
loss = (loss * mask).sum() / (mask.sum() + 1e-6)
val_loss_total += loss.detach()