mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-19 16:41:10 +08:00
fix: 验证
This commit is contained in:
parent
a0faada62b
commit
36265a9bce
@ -196,7 +196,7 @@ def train():
|
|||||||
idx += 1
|
idx += 1
|
||||||
|
|
||||||
# 2. 从头完整生成
|
# 2. 从头完整生成
|
||||||
map = torch.full((1, MAP_SIZE), MASK_TOKEN).to(device)
|
map = torch.full((B, MAP_SIZE), MASK_TOKEN).to(device)
|
||||||
for i in range(GENERATE_STEP):
|
for i in range(GENERATE_STEP):
|
||||||
# 1. 预测
|
# 1. 预测
|
||||||
logits = model(map, cond, heatmap) # [1, H * W, num_classes]
|
logits = model(map, cond, heatmap) # [1, H * W, num_classes]
|
||||||
@ -223,7 +223,7 @@ def train():
|
|||||||
if (map == MASK_TOKEN).sum() == 0:
|
if (map == MASK_TOKEN).sum() == 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
generated_img = matrix_to_image_cv(map.view(1, H, W)[0].cpu().numpy(), tile_dict)
|
generated_img = matrix_to_image_cv(map.view(B, H, W)[0].cpu().numpy(), tile_dict)
|
||||||
cv2.imwrite(f"result/transformer_img/g-{idx}.png", generated_img)
|
cv2.imwrite(f"result/transformer_img/g-{idx}.png", generated_img)
|
||||||
|
|
||||||
avg_loss_val = val_loss_total.item() / len(dataloader_val)
|
avg_loss_val = val_loss_total.item() / len(dataloader_val)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user