From c9bb50d503407990d30a78a52a675fe8e71186c4 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Wed, 11 Mar 2026 16:43:00 +0800 Subject: [PATCH] fix: train maskgit --- ginka/train_maskGIT.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ginka/train_maskGIT.py b/ginka/train_maskGIT.py index bfb8bb4..5cb1a51 100644 --- a/ginka/train_maskGIT.py +++ b/ginka/train_maskGIT.py @@ -42,8 +42,8 @@ from .maskGIT.mask import MapMask # 0. 空地, 1. 墙壁, 2. 门, 3. 钥匙, 4. 红宝石, 5. 蓝宝石, 6. 绿宝石, 7. 血瓶 # 8. 道具, 9. 怪物, 10. 入口, 15. 掩码 token -BATCH_SIZE = 16 -VAL_BATCH_DIVIDER = 16 +BATCH_SIZE = 128 +VAL_BATCH_DIVIDER = 128 NUM_CLASSES = 16 MASK_TOKEN = 15 GENERATE_STEP = 8 @@ -126,7 +126,7 @@ def train(): for i in range(B): mask[i] = masker.mask(H, W) - mask = torch.from_numpy(mask).to(torch.bool) + mask = torch.from_numpy(mask).to(torch.bool).to(device) # 掩码 masked_input = target_map.clone() @@ -178,7 +178,7 @@ def train(): for i in range(B): mask[i] = masker.mask(H, W) - mask = torch.from_numpy(mask).to(torch.bool) + mask = torch.from_numpy(mask).to(torch.bool).to(device) # 2. 生成掩码矩阵 masked_input = target_map.clone() @@ -204,7 +204,7 @@ def train(): map = torch.full((1, MAP_SIZE), MASK_TOKEN).to(device) for i in range(GENERATE_STEP): # 1. 预测 - logits = model(map, cond) # [1, H * W, num_classes] + logits = model(map, cond, heatmap) # [1, H * W, num_classes] probs = F.softmax(logits, dim=-1) # 2. 采样(为了多样性,这里可以使用概率采样而不是取最大值)