chore: 调整训练

This commit is contained in:
unanmed 2026-04-10 12:49:21 +08:00
parent 54164b9f22
commit 4059f0e05a

View File

@ -235,14 +235,12 @@ def get_nms_sampling_count():
np.random.randint(2, 10)
]
@torch.no_grad()
def full_generate(heatmap, maskGIT, cond_heatmap: torch.Tensor, diffusion: Diffusion):
fake_heatmap_cond = diffusion.sample(heatmap, cond_heatmap)
fake_heatmap_uncond = diffusion.sample(heatmap, torch.zeros_like(cond_heatmap))
fake_heatmap = fake_heatmap_uncond + W * (fake_heatmap_uncond - fake_heatmap_cond)
return maskGIT_generate(maskGIT, cond_heatmap.shape[0], fake_heatmap)
@torch.no_grad()
def maskGIT_generate(maskGIT, B: int, heatmap: torch.Tensor):
map = torch.full((B, MAP_H * MAP_W), MASK_TOKEN).to(device)
for i in range(GENERATE_STEP):