From f35d0b03af9ba2681d9e90eed6df740937a965e9 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Wed, 11 Mar 2026 18:26:15 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E9=AA=8C=E8=AF=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/train_maskGIT.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ginka/train_maskGIT.py b/ginka/train_maskGIT.py index f80f446..0903464 100644 --- a/ginka/train_maskGIT.py +++ b/ginka/train_maskGIT.py @@ -37,7 +37,7 @@ from .maskGIT.mask import MapMask # 6. 道具热力图, 7. 入口热力图, 8. 门热力图 BATCH_SIZE = 128 -VAL_BATCH_DIVIDER = 128 +VAL_BATCH_DIVIDER = 64 NUM_CLASSES = 16 MASK_TOKEN = 15 GENERATE_STEP = 8 @@ -182,7 +182,7 @@ def train(): logits = model(masked_input, cond, heatmap) loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none', label_smoothing=0.1) - loss = (loss * mask.view(-1)).sum() / (mask.sum() + 1e-6) + loss = (loss * mask).sum() / (mask.sum() + 1e-6) val_loss_total += loss.detach()