diff --git a/ginka/train_heatmap.py b/ginka/train_heatmap.py index e9c5775..546688d 100644 --- a/ginka/train_heatmap.py +++ b/ginka/train_heatmap.py @@ -235,14 +235,12 @@ def get_nms_sampling_count(): np.random.randint(2, 10) ] -@torch.no_grad() def full_generate(heatmap, maskGIT, cond_heatmap: torch.Tensor, diffusion: Diffusion): fake_heatmap_cond = diffusion.sample(heatmap, cond_heatmap) fake_heatmap_uncond = diffusion.sample(heatmap, torch.zeros_like(cond_heatmap)) fake_heatmap = fake_heatmap_uncond + W * (fake_heatmap_uncond - fake_heatmap_cond) return maskGIT_generate(maskGIT, cond_heatmap.shape[0], fake_heatmap) -@torch.no_grad() def maskGIT_generate(maskGIT, B: int, heatmap: torch.Tensor): map = torch.full((B, MAP_H * MAP_W), MASK_TOKEN).to(device) for i in range(GENERATE_STEP):