From 36265a9bce97cb60bde0be2cdca070a7a99563ed Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Wed, 11 Mar 2026 22:57:47 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E9=AA=8C=E8=AF=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/train_maskGIT.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ginka/train_maskGIT.py b/ginka/train_maskGIT.py index 0903464..26d0df7 100644 --- a/ginka/train_maskGIT.py +++ b/ginka/train_maskGIT.py @@ -196,7 +196,7 @@ def train(): idx += 1 # 2. 从头完整生成 - map = torch.full((1, MAP_SIZE), MASK_TOKEN).to(device) + map = torch.full((B, MAP_SIZE), MASK_TOKEN).to(device) for i in range(GENERATE_STEP): # 1. 预测 logits = model(map, cond, heatmap) # [1, H * W, num_classes] @@ -223,7 +223,7 @@ def train(): if (map == MASK_TOKEN).sum() == 0: break - generated_img = matrix_to_image_cv(map.view(1, H, W)[0].cpu().numpy(), tile_dict) + generated_img = matrix_to_image_cv(map.view(B, H, W)[0].cpu().numpy(), tile_dict) cv2.imwrite(f"result/transformer_img/g-{idx}.png", generated_img) avg_loss_val = val_loss_total.item() / len(dataloader_val)