mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-23 12:21:11 +08:00
fix: train maskgit
This commit is contained in:
parent
c000b90794
commit
c9bb50d503
@ -42,8 +42,8 @@ from .maskGIT.mask import MapMask
|
|||||||
# 0. 空地, 1. 墙壁, 2. 门, 3. 钥匙, 4. 红宝石, 5. 蓝宝石, 6. 绿宝石, 7. 血瓶
|
# 0. 空地, 1. 墙壁, 2. 门, 3. 钥匙, 4. 红宝石, 5. 蓝宝石, 6. 绿宝石, 7. 血瓶
|
||||||
# 8. 道具, 9. 怪物, 10. 入口, 15. 掩码 token
|
# 8. 道具, 9. 怪物, 10. 入口, 15. 掩码 token
|
||||||
|
|
||||||
BATCH_SIZE = 16
|
BATCH_SIZE = 128
|
||||||
VAL_BATCH_DIVIDER = 16
|
VAL_BATCH_DIVIDER = 128
|
||||||
NUM_CLASSES = 16
|
NUM_CLASSES = 16
|
||||||
MASK_TOKEN = 15
|
MASK_TOKEN = 15
|
||||||
GENERATE_STEP = 8
|
GENERATE_STEP = 8
|
||||||
@ -126,7 +126,7 @@ def train():
|
|||||||
for i in range(B):
|
for i in range(B):
|
||||||
mask[i] = masker.mask(H, W)
|
mask[i] = masker.mask(H, W)
|
||||||
|
|
||||||
mask = torch.from_numpy(mask).to(torch.bool)
|
mask = torch.from_numpy(mask).to(torch.bool).to(device)
|
||||||
|
|
||||||
# 掩码
|
# 掩码
|
||||||
masked_input = target_map.clone()
|
masked_input = target_map.clone()
|
||||||
@ -178,7 +178,7 @@ def train():
|
|||||||
for i in range(B):
|
for i in range(B):
|
||||||
mask[i] = masker.mask(H, W)
|
mask[i] = masker.mask(H, W)
|
||||||
|
|
||||||
mask = torch.from_numpy(mask).to(torch.bool)
|
mask = torch.from_numpy(mask).to(torch.bool).to(device)
|
||||||
|
|
||||||
# 2. 生成掩码矩阵
|
# 2. 生成掩码矩阵
|
||||||
masked_input = target_map.clone()
|
masked_input = target_map.clone()
|
||||||
@ -204,7 +204,7 @@ def train():
|
|||||||
map = torch.full((1, MAP_SIZE), MASK_TOKEN).to(device)
|
map = torch.full((1, MAP_SIZE), MASK_TOKEN).to(device)
|
||||||
for i in range(GENERATE_STEP):
|
for i in range(GENERATE_STEP):
|
||||||
# 1. 预测
|
# 1. 预测
|
||||||
logits = model(map, cond) # [1, H * W, num_classes]
|
logits = model(map, cond, heatmap) # [1, H * W, num_classes]
|
||||||
probs = F.softmax(logits, dim=-1)
|
probs = F.softmax(logits, dim=-1)
|
||||||
|
|
||||||
# 2. 采样(为了多样性,这里可以使用概率采样而不是取最大值)
|
# 2. 采样(为了多样性,这里可以使用概率采样而不是取最大值)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user