fix: Diffusion 训练热力图改为 [-1, 1] 范围

This commit is contained in:
unanmed 2026-04-25 16:18:16 +08:00
parent 966d007721
commit b471bb46eb
3 changed files with 10 additions and 9 deletions

View File

@ -9,7 +9,7 @@ class Diffusion:
# cosine schedule推荐
steps = torch.arange(T + 1, dtype=torch.float32)
s = 0.1
f = torch.cos(((steps / T) + s) / (1 + s) * math.pi * 0.5) ** 2
f = torch.cos(((steps / (T + 1)) + s) / (1 + s) * math.pi * 0.5) ** 2
alpha_bar = f / f[0]
self.alpha_bar = alpha_bar.to(device)
@ -51,3 +51,4 @@ class Diffusion:
if __name__ == '__main__':
    # Manual smoke check: construct the schedule on CPU and print the two
    # coefficient tables, in the same order as before.
    diff = Diffusion("cpu")
    for table_name in ("sqrt_one_minus_ab", "sqrt_ab"):
        print(getattr(diff, table_name))

View File

@ -129,7 +129,7 @@ def train():
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
cond_heatmap = batch["cond_heatmap"].to(device)
target_heatmap = batch["target_heatmap"].to(device)
target_heatmap = batch["target_heatmap"].to(device) * 2 - 1
B, C, H, W = target_heatmap.shape
optimizer.zero_grad()
@ -175,7 +175,7 @@ def train():
for batch in tqdm(dataloader_val, desc="Validating", leave=False, disable=disable_tqdm):
# 1. 验证集验证
cond_heatmap = batch["cond_heatmap"].to(device)
target_heatmap = batch["target_heatmap"].to(device)
target_heatmap = batch["target_heatmap"].to(device) * 2 - 1
B, C, H, W = target_heatmap.shape
t = torch.randint(1, T_DIFFUSION, [B], device=device)
@ -236,8 +236,8 @@ def get_nms_sampling_count():
]
def full_generate(heatmap, maskGIT, cond_heatmap: torch.Tensor, diffusion: Diffusion):
    """Sample a guided heatmap from the diffusion model and decode it with MaskGIT.

    Two samples are drawn from ``diffusion.sample``: one using ``cond_heatmap``
    and one using an all-zero tensor of the same shape as the condition. Each
    sample is mapped through ``(x + 1) / 2`` (presumably from [-1, 1] back to
    [0, 1] -- confirm against the training-side scaling), the two are blended
    with the guidance weight ``W`` read from the enclosing module scope, and
    the result is handed to ``maskGIT_generate``.

    Args:
        heatmap: Passed unchanged as the first argument of ``diffusion.sample``.
        maskGIT: Model forwarded to ``maskGIT_generate``.
        cond_heatmap: Conditioning tensor; its first dimension is used as the
            batch size for ``maskGIT_generate``.
        diffusion: Object providing ``sample(heatmap, cond)``.

    Returns:
        Whatever ``maskGIT_generate`` returns for the blended heatmap.
    """
    # Each branch is sampled exactly once. Sampling twice and discarding the
    # first result doubles the cost and advances the RNG for nothing.
    cond_sample = diffusion.sample(heatmap, cond_heatmap)
    uncond_sample = diffusion.sample(heatmap, torch.zeros_like(cond_heatmap))

    fake_heatmap_cond = (cond_sample + 1) / 2
    fake_heatmap_uncond = (uncond_sample + 1) / 2

    # NOTE(review): the usual classifier-free guidance form is
    # uncond + W * (cond - uncond); this expression has the opposite sign on
    # the difference term. Kept as-is -- confirm whether that is intentional.
    fake_heatmap = fake_heatmap_uncond + W * (fake_heatmap_uncond - fake_heatmap_cond)  # [B, C, H, W]
    return maskGIT_generate(maskGIT, cond_heatmap.shape[0], fake_heatmap)

View File

@ -233,7 +233,7 @@ def validate(model, maskgit, diffusion, dataloader, ce_weight: float, tile_dict)
preview_idx = 0
for batch in tqdm(dataloader, desc="Validating", leave=False, disable=disable_tqdm):
cond_heatmap = batch["cond_heatmap"].to(device)
target_heatmap = batch["target_heatmap"].to(device)
target_heatmap = batch["target_heatmap"].to(device) * 2 - 1
target_map = batch["target_map"].to(device)
batch_size, _, map_height, map_width = target_heatmap.shape
@ -244,7 +244,7 @@ def validate(model, maskgit, diffusion, dataloader, ce_weight: float, tile_dict)
pred_noise = model(x_t, cond_heatmap, t)
diffusion_loss = F.mse_loss(pred_noise, noise)
generated_heatmap = predict_x0(diffusion, x_t, pred_noise, t)
generated_heatmap = (predict_x0(diffusion, x_t, pred_noise, t) + 1) / 2
maskgit_loss = maskgit_joint_loss(maskgit, generated_heatmap, target_map)
loss = diffusion_loss + ce_weight * maskgit_loss
@ -325,7 +325,7 @@ def train():
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
cond_heatmap = batch["cond_heatmap"].to(device)
target_heatmap = batch["target_heatmap"].to(device)
target_heatmap = batch["target_heatmap"].to(device) * 2 - 1
target_map = batch["target_map"].to(device)
batch_size = target_heatmap.shape[0]
@ -348,7 +348,7 @@ def train():
if use_unconditional_branch:
pred_noise_for_joint = model(x_t, cond_heatmap, t)
generated_heatmap = predict_x0(diffusion, x_t, pred_noise_for_joint, t)
generated_heatmap = (predict_x0(diffusion, x_t, pred_noise_for_joint, t) + 1) / 2
print(torch.mean(generated_heatmap), torch.std(generated_heatmap), generated_heatmap.shape)
maskgit_loss = maskgit_joint_loss(maskgit, generated_heatmap, target_map)