From 4059f0e05adbd8cd90bd65cbcca773378f0fd5b0 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Fri, 10 Apr 2026 12:49:21 +0800 Subject: [PATCH] =?UTF-8?q?chore:=20=E8=B0=83=E6=95=B4=E8=AE=AD=E7=BB=83?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/train_heatmap.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/ginka/train_heatmap.py b/ginka/train_heatmap.py index e9c5775..546688d 100644 --- a/ginka/train_heatmap.py +++ b/ginka/train_heatmap.py @@ -235,14 +235,12 @@ def get_nms_sampling_count(): np.random.randint(2, 10) ] -@torch.no_grad() def full_generate(heatmap, maskGIT, cond_heatmap: torch.Tensor, diffusion: Diffusion): fake_heatmap_cond = diffusion.sample(heatmap, cond_heatmap) fake_heatmap_uncond = diffusion.sample(heatmap, torch.zeros_like(cond_heatmap)) fake_heatmap = fake_heatmap_uncond + W * (fake_heatmap_uncond - fake_heatmap_cond) return maskGIT_generate(maskGIT, cond_heatmap.shape[0], fake_heatmap) -@torch.no_grad() def maskGIT_generate(maskGIT, B: int, heatmap: torch.Tensor): map = torch.full((B, MAP_H * MAP_W), MASK_TOKEN).to(device) for i in range(GENERATE_STEP):