diff --git a/ginka/train_maskGIT.py b/ginka/train_maskGIT.py index f80f446..0903464 100644 --- a/ginka/train_maskGIT.py +++ b/ginka/train_maskGIT.py @@ -37,7 +37,7 @@ from .maskGIT.mask import MapMask # 6. 道具热力图, 7. 入口热力图, 8. 门热力图 BATCH_SIZE = 128 -VAL_BATCH_DIVIDER = 128 +VAL_BATCH_DIVIDER = 64 NUM_CLASSES = 16 MASK_TOKEN = 15 GENERATE_STEP = 8 @@ -182,7 +182,7 @@ def train(): logits = model(masked_input, cond, heatmap) loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none', label_smoothing=0.1) - loss = (loss * mask.view(-1)).sum() / (mask.sum() + 1e-6) + loss = (loss * mask).sum() / (mask.sum() + 1e-6) val_loss_total += loss.detach()