mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-19 00:01:13 +08:00
fix: train maskgit
This commit is contained in:
parent
c000b90794
commit
c9bb50d503
@ -42,8 +42,8 @@ from .maskGIT.mask import MapMask
|
||||
# 0. 空地, 1. 墙壁, 2. 门, 3. 钥匙, 4. 红宝石, 5. 蓝宝石, 6. 绿宝石, 7. 血瓶
|
||||
# 8. 道具, 9. 怪物, 10. 入口, 15. 掩码 token
|
||||
|
||||
BATCH_SIZE = 16
|
||||
VAL_BATCH_DIVIDER = 16
|
||||
BATCH_SIZE = 128
|
||||
VAL_BATCH_DIVIDER = 128
|
||||
NUM_CLASSES = 16
|
||||
MASK_TOKEN = 15
|
||||
GENERATE_STEP = 8
|
||||
@ -126,7 +126,7 @@ def train():
|
||||
for i in range(B):
|
||||
mask[i] = masker.mask(H, W)
|
||||
|
||||
mask = torch.from_numpy(mask).to(torch.bool)
|
||||
mask = torch.from_numpy(mask).to(torch.bool).to(device)
|
||||
|
||||
# 掩码
|
||||
masked_input = target_map.clone()
|
||||
@ -178,7 +178,7 @@ def train():
|
||||
for i in range(B):
|
||||
mask[i] = masker.mask(H, W)
|
||||
|
||||
mask = torch.from_numpy(mask).to(torch.bool)
|
||||
mask = torch.from_numpy(mask).to(torch.bool).to(device)
|
||||
|
||||
# 2. 生成掩码矩阵
|
||||
masked_input = target_map.clone()
|
||||
@ -204,7 +204,7 @@ def train():
|
||||
map = torch.full((1, MAP_SIZE), MASK_TOKEN).to(device)
|
||||
for i in range(GENERATE_STEP):
|
||||
# 1. 预测
|
||||
logits = model(map, cond) # [1, H * W, num_classes]
|
||||
logits = model(map, cond, heatmap) # [1, H * W, num_classes]
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
|
||||
# 2. 采样(为了多样性,这里可以使用概率采样而不是取最大值)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user