fix: train maskgit

This commit is contained in:
unanmed 2026-03-11 16:43:00 +08:00
parent c000b90794
commit c9bb50d503

View File

@ -42,8 +42,8 @@ from .maskGIT.mask import MapMask
# 0. 空地, 1. 墙壁, 2. 门, 3. 钥匙, 4. 红宝石, 5. 蓝宝石, 6. 绿宝石, 7. 血瓶
# 8. 道具, 9. 怪物, 10. 入口, 15. 掩码 token
BATCH_SIZE = 16
VAL_BATCH_DIVIDER = 16
BATCH_SIZE = 128
VAL_BATCH_DIVIDER = 128
NUM_CLASSES = 16
MASK_TOKEN = 15
GENERATE_STEP = 8
@ -126,7 +126,7 @@ def train():
for i in range(B):
mask[i] = masker.mask(H, W)
mask = torch.from_numpy(mask).to(torch.bool)
mask = torch.from_numpy(mask).to(torch.bool).to(device)
# 掩码
masked_input = target_map.clone()
@ -178,7 +178,7 @@ def train():
for i in range(B):
mask[i] = masker.mask(H, W)
mask = torch.from_numpy(mask).to(torch.bool)
mask = torch.from_numpy(mask).to(torch.bool).to(device)
# 2. 生成掩码矩阵
masked_input = target_map.clone()
@ -204,7 +204,7 @@ def train():
map = torch.full((1, MAP_SIZE), MASK_TOKEN).to(device)
for i in range(GENERATE_STEP):
# 1. 预测
logits = model(map, cond) # [1, H * W, num_classes]
logits = model(map, cond, heatmap) # [1, H * W, num_classes]
probs = F.softmax(logits, dim=-1)
# 2. 采样(为了多样性,这里可以使用概率采样而不是取最大值)