mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
chore: 调整训练
This commit is contained in:
parent
54164b9f22
commit
4059f0e05a
@ -235,14 +235,12 @@ def get_nms_sampling_count():
|
||||
np.random.randint(2, 10)
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def full_generate(heatmap, maskGIT, cond_heatmap: torch.Tensor, diffusion: Diffusion):
|
||||
fake_heatmap_cond = diffusion.sample(heatmap, cond_heatmap)
|
||||
fake_heatmap_uncond = diffusion.sample(heatmap, torch.zeros_like(cond_heatmap))
|
||||
fake_heatmap = fake_heatmap_uncond + W * (fake_heatmap_uncond - fake_heatmap_cond)
|
||||
return maskGIT_generate(maskGIT, cond_heatmap.shape[0], fake_heatmap)
|
||||
|
||||
@torch.no_grad()
|
||||
def maskGIT_generate(maskGIT, B: int, heatmap: torch.Tensor):
|
||||
map = torch.full((B, MAP_H * MAP_W), MASK_TOKEN).to(device)
|
||||
for i in range(GENERATE_STEP):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user