From b471bb46eb4594620d50d93011e358613e7e8990 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Sat, 25 Apr 2026 16:18:16 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20Diffusion=20=E8=AE=AD=E7=BB=83=E7=83=AD?= =?UTF-8?q?=E5=8A=9B=E5=9B=BE=E6=94=B9=E4=B8=BA=20-1-1=20=E8=8C=83?= =?UTF-8?q?=E5=9B=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/heatmap/diffusion.py | 3 ++- ginka/train_heatmap.py | 8 ++++---- ginka/train_joint.py | 8 ++++---- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/ginka/heatmap/diffusion.py b/ginka/heatmap/diffusion.py index ec31c08..4444f6e 100644 --- a/ginka/heatmap/diffusion.py +++ b/ginka/heatmap/diffusion.py @@ -9,7 +9,7 @@ class Diffusion: # cosine schedule(推荐) steps = torch.arange(T + 1, dtype=torch.float32) s = 0.1 - f = torch.cos(((steps / T) + s) / (1 + s) * math.pi * 0.5) ** 2 + f = torch.cos(((steps / (T + 1)) + s) / (1 + s) * math.pi * 0.5) ** 2 alpha_bar = f / f[0] self.alpha_bar = alpha_bar.to(device) @@ -51,3 +51,4 @@ class Diffusion: if __name__ == '__main__': diff = Diffusion("cpu") print(diff.sqrt_one_minus_ab) + print(diff.sqrt_ab) diff --git a/ginka/train_heatmap.py b/ginka/train_heatmap.py index edec34b..3dfb9f2 100644 --- a/ginka/train_heatmap.py +++ b/ginka/train_heatmap.py @@ -129,7 +129,7 @@ def train(): for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm): cond_heatmap = batch["cond_heatmap"].to(device) - target_heatmap = batch["target_heatmap"].to(device) + target_heatmap = batch["target_heatmap"].to(device) * 2 - 1 B, C, H, W = target_heatmap.shape optimizer.zero_grad() @@ -175,7 +175,7 @@ def train(): for batch in tqdm(dataloader_val, desc="Validating", leave=False, disable=disable_tqdm): # 1. 验证集验证 cond_heatmap = batch["cond_heatmap"].to(device) - target_heatmap = batch["target_heatmap"].to(device) + target_heatmap = batch["target_heatmap"].to(device) * 2 - 1 B, C, H, W = target_heatmap.shape t = torch.randint(1, T_DIFFUSION, [B], device=device) @@ -236,8 +236,8 @@ def get_nms_sampling_count(): ] def full_generate(heatmap, maskGIT, cond_heatmap: torch.Tensor, diffusion: Diffusion): - fake_heatmap_cond = diffusion.sample(heatmap, cond_heatmap) - fake_heatmap_uncond = diffusion.sample(heatmap, torch.zeros_like(cond_heatmap)) + fake_heatmap_cond = (diffusion.sample(heatmap, cond_heatmap) + 1) / 2 + fake_heatmap_uncond = (diffusion.sample(heatmap, torch.zeros_like(cond_heatmap)) + 1) / 2 fake_heatmap = fake_heatmap_uncond + W * (fake_heatmap_uncond - fake_heatmap_cond) # [B, C, H, W] return maskGIT_generate(maskGIT, cond_heatmap.shape[0], fake_heatmap) diff --git a/ginka/train_joint.py b/ginka/train_joint.py index 34f01e7..4f7e4a3 100644 --- a/ginka/train_joint.py +++ b/ginka/train_joint.py @@ -233,7 +233,7 @@ def validate(model, maskgit, diffusion, dataloader, ce_weight: float, tile_dict) preview_idx = 0 for batch in tqdm(dataloader, desc="Validating", leave=False, disable=disable_tqdm): cond_heatmap = batch["cond_heatmap"].to(device) - target_heatmap = batch["target_heatmap"].to(device) + target_heatmap = batch["target_heatmap"].to(device) * 2 - 1 target_map = batch["target_map"].to(device) batch_size, _, map_height, map_width = target_heatmap.shape @@ -244,7 +244,7 @@ def validate(model, maskgit, diffusion, dataloader, ce_weight: float, tile_dict) pred_noise = model(x_t, cond_heatmap, t) diffusion_loss = F.mse_loss(pred_noise, noise) - generated_heatmap = predict_x0(diffusion, x_t, pred_noise, t) + generated_heatmap = (predict_x0(diffusion, x_t, pred_noise, t) + 1) / 2 maskgit_loss = maskgit_joint_loss(maskgit, generated_heatmap, target_map) loss = diffusion_loss + ce_weight * maskgit_loss @@ -325,7 +325,7 @@ def train(): for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm): cond_heatmap = batch["cond_heatmap"].to(device) - target_heatmap = batch["target_heatmap"].to(device) + target_heatmap = batch["target_heatmap"].to(device) * 2 - 1 target_map = batch["target_map"].to(device) batch_size = target_heatmap.shape[0] @@ -348,7 +348,7 @@ def train(): if use_unconditional_branch: pred_noise_for_joint = model(x_t, cond_heatmap, t) - generated_heatmap = predict_x0(diffusion, x_t, pred_noise_for_joint, t) + generated_heatmap = (predict_x0(diffusion, x_t, pred_noise_for_joint, t) + 1) / 2 print(torch.mean(generated_heatmap), torch.std(generated_heatmap), generated_heatmap.shape) maskgit_loss = maskgit_joint_loss(maskgit, generated_heatmap, target_map)