diff --git a/ginka/heatmap/diffusion.py b/ginka/heatmap/diffusion.py index 4444f6e..6e3a53e 100644 --- a/ginka/heatmap/diffusion.py +++ b/ginka/heatmap/diffusion.py @@ -2,9 +2,10 @@ import math import torch class Diffusion: - def __init__(self, device, T=100): + def __init__(self, device, T=100, noise_scale=0.5): self.T = T self.device = device + self.noise_scale = noise_scale # cosine schedule(推荐) steps = torch.arange(T + 1, dtype=torch.float32) @@ -18,33 +19,41 @@ class Diffusion: def q_sample(self, x0, t, noise): """ - 前向加噪 + 前向加噪:x_t = sqrt(αbar_t) * x0 + sqrt(1-αbar_t) * noise_scale * ε + noise_scale 降低噪声功率,使信号不被淹没 """ return ( self.sqrt_ab[t][:, None, None, None] * x0 - + self.sqrt_one_minus_ab[t][:, None, None, None] * noise + + self.sqrt_one_minus_ab[t][:, None, None, None] * noise * self.noise_scale ) def sample(self, model, cond: torch.Tensor, steps=20): + """ + DDIM 风格逆向采样,模型预测 x_0 + x_{t-1} = sqrt(αbar_{t-1}) * x0_pred + + sqrt(1-αbar_{t-1}) / sqrt(1-αbar_t) * (x_t - sqrt(αbar_t) * x0_pred) + """ B = cond.shape[0] - x = torch.randn_like(cond).to(cond.device) + # 初始噪声与前向过程保持一致的噪声功率 + x = torch.randn_like(cond).to(cond.device) * self.noise_scale step_size = self.T // steps for i in reversed(range(0, self.T, step_size)): t = torch.full((B,), i, device=cond.device) - pred_noise = model(x, cond, t) + # 模型直接预测 x_0 + x0_pred = model(x, cond, t) alpha = self.alpha_bar[i] alpha_prev = self.alpha_bar[max(i - step_size, 0)] - x0_pred = (x - torch.sqrt(1 - alpha) * pred_noise) / torch.sqrt(alpha) + # DDIM x0-prediction 更新 + direction = ( + torch.sqrt(1 - alpha_prev) / torch.sqrt(1 - alpha) + ) * (x - torch.sqrt(alpha) * x0_pred) - x = ( - torch.sqrt(alpha_prev) * x0_pred - + torch.sqrt(1 - alpha_prev) * pred_noise - ) + x = torch.sqrt(alpha_prev) * x0_pred + direction return x diff --git a/ginka/train_heatmap.py b/ginka/train_heatmap.py index 3dfb9f2..97c65ac 100644 --- a/ginka/train_heatmap.py +++ b/ginka/train_heatmap.py @@ -48,6 +48,7 @@ D_MODEL_DIFFUSION = 128 T_DIFFUSION = 100 MIN_MASK = 0 MAX_MASK = 1 +NOISE_SCALE = 0.3 W = 5 # CFG 参数 device = torch.device( @@ -91,7 +92,7 @@ def train(): num_layers=NUM_LAYERS_DIFFUSION ).to(device) - diffusion = Diffusion(device) + diffusion = Diffusion(device, noise_scale=NOISE_SCALE) dataset = GinkaHeatmapDataset(args.train, min_mask=MIN_MASK, max_mask=MAX_MASK) dataset_val = GinkaHeatmapDataset(args.validate, min_mask=MIN_MASK, max_mask=MAX_MASK) @@ -129,7 +130,7 @@ def train(): for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm): cond_heatmap = batch["cond_heatmap"].to(device) - target_heatmap = batch["target_heatmap"].to(device) * 2 - 1 + target_heatmap = batch["target_heatmap"].to(device) B, C, H, W = target_heatmap.shape optimizer.zero_grad() @@ -143,9 +144,10 @@ def train(): if np.random.rand() < 0.2: cond_heatmap = torch.zeros_like(cond_heatmap) - pred_noise = model(x_t, cond_heatmap, t) + # 模型预测 x_0,损失直接对齐热力图 + pred_x0 = model(x_t, cond_heatmap, t) - loss = F.mse_loss(pred_noise, noise) + loss = F.mse_loss(pred_x0, target_heatmap) loss.backward() optimizer.step() @@ -175,7 +177,7 @@ def train(): for batch in tqdm(dataloader_val, desc="Validating", leave=False, disable=disable_tqdm): # 1. 验证集验证 cond_heatmap = batch["cond_heatmap"].to(device) - target_heatmap = batch["target_heatmap"].to(device) * 2 - 1 + target_heatmap = batch["target_heatmap"].to(device) B, C, H, W = target_heatmap.shape t = torch.randint(1, T_DIFFUSION, [B], device=device) @@ -183,9 +185,9 @@ def train(): x_t = diffusion.q_sample(target_heatmap, t, noise) - pred_noise = model(x_t, cond_heatmap, t) + pred_x0 = model(x_t, cond_heatmap, t) - loss = F.mse_loss(pred_noise, noise) + loss = F.mse_loss(pred_x0, target_heatmap) val_loss_total += loss.detach() @@ -236,8 +238,8 @@ def get_nms_sampling_count(): ] def full_generate(heatmap, maskGIT, cond_heatmap: torch.Tensor, diffusion: Diffusion): - fake_heatmap_cond = (diffusion.sample(heatmap, cond_heatmap) + 1) / 2 - fake_heatmap_uncond = (diffusion.sample(heatmap, torch.zeros_like(cond_heatmap)) + 1) / 2 + fake_heatmap_cond = diffusion.sample(heatmap, cond_heatmap) + fake_heatmap_uncond = diffusion.sample(heatmap, torch.zeros_like(cond_heatmap)) fake_heatmap = fake_heatmap_uncond + W * (fake_heatmap_uncond - fake_heatmap_cond) # [B, C, H, W] return maskGIT_generate(maskGIT, cond_heatmap.shape[0], fake_heatmap) diff --git a/ginka/train_joint.py b/ginka/train_joint.py index 4f7e4a3..96e29c7 100644 --- a/ginka/train_joint.py +++ b/ginka/train_joint.py @@ -46,6 +46,7 @@ D_MODEL_DIFFUSION = 128 T_DIFFUSION = 100 MIN_MASK = 0 MAX_MASK = 1 +NOISE_SCALE = 0.3 # 验证预览配置 PREVIEW_CFG_WEIGHT = 5 # 预览生成时使用的 CFG 强度 @@ -102,14 +103,6 @@ def freeze_module(module: torch.nn.Module): parameter.requires_grad = False -def predict_x0(diffusion: Diffusion, x_t: torch.Tensor, pred_noise: torch.Tensor, t: torch.Tensor): - # 根据当前时刻的噪声预测还原 x0 热力图估计。 - sqrt_ab = diffusion.sqrt_ab[t][:, None, None, None] - sqrt_one_minus_ab = diffusion.sqrt_one_minus_ab[t][:, None, None, None] - x0 = (x_t - sqrt_one_minus_ab * pred_noise) / sqrt_ab - return x0 - - def maskgit_joint_loss(maskgit, generated_heatmap: torch.Tensor, target_map: torch.Tensor): # 用冻结的 MaskGIT 对 Diffusion 生成的热力图施加地图级监督。 batch_size, height, width = target_map.shape @@ -233,7 +226,7 @@ def validate(model, maskgit, diffusion, dataloader, ce_weight: float, tile_dict) preview_idx = 0 for batch in tqdm(dataloader, desc="Validating", leave=False, disable=disable_tqdm): cond_heatmap = batch["cond_heatmap"].to(device) - target_heatmap = batch["target_heatmap"].to(device) * 2 - 1 + target_heatmap = batch["target_heatmap"].to(device) target_map = batch["target_map"].to(device) batch_size, _, map_height, map_width = target_heatmap.shape @@ -241,11 +234,10 @@ def validate(model, maskgit, diffusion, dataloader, ce_weight: float, tile_dict) noise = torch.randn_like(target_heatmap) x_t = diffusion.q_sample(target_heatmap, t, noise) - pred_noise = model(x_t, cond_heatmap, t) - diffusion_loss = F.mse_loss(pred_noise, noise) + pred_x0 = model(x_t, cond_heatmap, t) + diffusion_loss = F.mse_loss(pred_x0, target_heatmap) - generated_heatmap = (predict_x0(diffusion, x_t, pred_noise, t) + 1) / 2 - maskgit_loss = maskgit_joint_loss(maskgit, generated_heatmap, target_map) + maskgit_loss = maskgit_joint_loss(maskgit, pred_x0, target_map) loss = diffusion_loss + ce_weight * maskgit_loss total_loss += loss.item() @@ -297,7 +289,7 @@ def train(): d_model=D_MODEL_DIFFUSION, num_layers=NUM_LAYERS_DIFFUSION, ).to(device) - diffusion = Diffusion(device, T=T_DIFFUSION) + diffusion = Diffusion(device, T=T_DIFFUSION, noise_scale=NOISE_SCALE) dataset = GinkaJointDataset(args.train, min_mask=MIN_MASK, max_mask=MAX_MASK) dataset_val = GinkaJointDataset(args.validate, min_mask=MIN_MASK, max_mask=MAX_MASK) @@ -325,7 +317,7 @@ def train(): for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm): cond_heatmap = batch["cond_heatmap"].to(device) - target_heatmap = batch["target_heatmap"].to(device) * 2 - 1 + target_heatmap = batch["target_heatmap"].to(device) target_map = batch["target_map"].to(device) batch_size = target_heatmap.shape[0] @@ -341,16 +333,15 @@ def train(): cond_for_diffusion = torch.zeros_like(cond_heatmap) use_unconditional_branch = True - pred_noise = model(x_t, cond_for_diffusion, t) - diffusion_loss = F.mse_loss(pred_noise, noise) + pred_x0 = model(x_t, cond_for_diffusion, t) + diffusion_loss = F.mse_loss(pred_x0, target_heatmap) - pred_noise_for_joint = pred_noise + # 若使用无条件分支,重新对有条件输入预测以计算联合损失 + pred_x0_for_joint = pred_x0 if use_unconditional_branch: - pred_noise_for_joint = model(x_t, cond_heatmap, t) + pred_x0_for_joint = model(x_t, cond_heatmap, t) - generated_heatmap = (predict_x0(diffusion, x_t, pred_noise_for_joint, t) + 1) / 2 - print(torch.mean(generated_heatmap), torch.std(generated_heatmap), generated_heatmap.shape) - maskgit_loss = maskgit_joint_loss(maskgit, generated_heatmap, target_map) + maskgit_loss = maskgit_joint_loss(maskgit, pred_x0_for_joint, target_map) loss = diffusion_loss + CE_WEIGHT * maskgit_loss loss.backward()