From 3e898dc5ba6394be04e767ba6bf33bf2131d96e7 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Tue, 7 Apr 2026 22:56:41 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E8=B0=83=E6=95=B4=20Diffusion=20?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/heatmap/cond.py | 2 +- ginka/heatmap/diffusion.py | 79 +++++++++++++++++++++----------------- ginka/heatmap/model.py | 18 +++++++-- ginka/train_heatmap.py | 19 ++++++--- 4 files changed, 72 insertions(+), 46 deletions(-) diff --git a/ginka/heatmap/cond.py b/ginka/heatmap/cond.py index 3d39198..9e71a69 100644 --- a/ginka/heatmap/cond.py +++ b/ginka/heatmap/cond.py @@ -40,7 +40,7 @@ class HeatmapCond(nn.Module): def forward(self, heatmap: torch.Tensor, t: torch.Tensor): # heatmap: [B, C, H, W] - # t: [B, 1] + # t: [B] t_embed = self.time_embedding(t) x = self.conv1(heatmap) + self.fc1(t_embed).unsqueeze(1).unsqueeze(1).permute(0, 3, 1, 2) x = self.conv2(x) + self.fc2(t_embed).unsqueeze(1).unsqueeze(1).permute(0, 3, 1, 2) diff --git a/ginka/heatmap/diffusion.py b/ginka/heatmap/diffusion.py index 27afa5e..fc8e5c8 100644 --- a/ginka/heatmap/diffusion.py +++ b/ginka/heatmap/diffusion.py @@ -2,49 +2,56 @@ import math import torch class Diffusion: - def __init__(self, device, T=100): + def __init__(self, device, T=100, min_beta=0.0001, max_beta=0.02): self.T = T self.device = device - # cosine schedule(推荐) - steps = torch.arange(T + 1, dtype=torch.float32) - s = 0.008 - f = torch.cos(((steps / T) + s) / (1 + s) * math.pi * 0.5) ** 2 - alpha_bar = f / f[0] - - self.alpha_bar = alpha_bar.to(device) - self.sqrt_ab = torch.sqrt(self.alpha_bar) - self.sqrt_one_minus_ab = torch.sqrt(1 - self.alpha_bar) + betas = torch.linspace(min_beta, max_beta, T).to(device) + alphas = 1 - betas + alpha_bars = torch.empty_like(alphas) + product = 1 + for i, alpha in enumerate(alphas): + product *= alpha + alpha_bars[i] = product + self.betas = betas + 
self.n_steps = T + self.alphas = alphas + self.alpha_bars = alpha_bars def q_sample(self, x0, t, noise): """ 前向加噪 """ - return ( - self.sqrt_ab[t][:, None, None, None] * x0 - + self.sqrt_one_minus_ab[t][:, None, None, None] * noise - ) + alpha_bar = self.alpha_bars[t].reshape(-1, 1, 1, 1) + res = noise * torch.sqrt(1 - alpha_bar) + torch.sqrt(alpha_bar) * x0 + return res - def sample(self, model, cond: torch.Tensor, steps=20): - B = cond.shape[0] - x = torch.randn_like(cond).to(cond.device) - - step_size = self.T // steps - - for i in reversed(range(0, self.T, step_size)): - t = torch.full((B,), i, device=cond.device) - - pred_noise = model(x, cond, t) - - alpha = self.alpha_bar[i] - alpha_prev = self.alpha_bar[max(i - step_size, 0)] - - x0_pred = (x - torch.sqrt(1 - alpha) * pred_noise) / torch.sqrt(alpha) - - x = ( - torch.sqrt(alpha_prev) * x0_pred - + torch.sqrt(1 - alpha_prev) * pred_noise - ) - + def sample(self, model, cond: torch.Tensor): + x = torch.randn_like(cond).to(self.device) + for t in range(self.n_steps - 1, -1, -1): + x = self.sample_backward_step(x, t, cond, model) return x - \ No newline at end of file + + def sample_backward_step(self, x_t, t, cond, model): + B = x_t.shape[0] + t_tensor = torch.tensor([t] * B, dtype=torch.long).to(self.device) + eps = model(x_t, cond, t_tensor) + + if t == 0: + noise = 0 + else: + var = (1 - self.alpha_bars[t - 1]) / (1 - self.alpha_bars[t]) * self.betas[t] + noise = torch.randn_like(x_t) + noise *= torch.sqrt(var) + + mean = (x_t - + (1 - self.alphas[t]) / torch.sqrt(1 - self.alpha_bars[t]) * + eps) / torch.sqrt(self.alphas[t]) + x_t = mean + noise + + return x_t + +if __name__ == '__main__': + diff = Diffusion("cpu") + print(diff.alphas) + print(diff.alpha_bars) diff --git a/ginka/heatmap/model.py b/ginka/heatmap/model.py index 9e29e5f..474a93c 100644 --- a/ginka/heatmap/model.py +++ b/ginka/heatmap/model.py @@ -16,8 +16,14 @@ class GinkaHeatmapModel(nn.Module): self.cond = HeatmapCond(T, embed_dim=embed_dim,
heatmap_dim=heatmap_dim, output_dim=d_model) self.input = HeatmapCond(T, embed_dim=embed_dim, heatmap_dim=heatmap_dim, output_dim=d_model) self.transformer = MaskGIT(d_model=d_model, dim_ff=dim_ff, nhead=nhead, num_layers=num_layers) + self.cross_attn = nn.MultiheadAttention(d_model, num_heads=nhead, batch_first=True) self.output_fc = nn.Sequential( - nn.Linear(d_model, heatmap_dim) + nn.Linear(d_model, d_model // 2), + nn.LayerNorm(d_model // 2), + nn.Dropout(0.3), + nn.GELU(), + + nn.Linear(d_model // 2, heatmap_dim) ) def forward(self, input: torch.Tensor, cond: torch.Tensor, t: torch.Tensor): @@ -26,11 +32,15 @@ class GinkaHeatmapModel(nn.Module): # t: [B, 1] input = self.input(input, t) # [B, d_model, H, W] cond = self.cond(cond, t) # [B, d_model, H, W] - hidden = input + cond - B, C, H, W = hidden.shape + B, C, H, W = cond.shape + cond_tokens = cond.view(B, C, H * W).permute(0, 2, 1) # [B, H * W, d_model] + scale = torch.sigmoid(cond) + hidden = input * (1 + scale) + cond hidden = hidden.view(B, C, H * W).permute(0, 2, 1) # [B, H * W, d_model] hidden = hidden + self.pos_embedding hidden = self.transformer(hidden) # [B, H * W, d_model] + attn, _ = self.cross_attn(hidden, cond_tokens, cond_tokens) + hidden = hidden + attn output = self.output_fc(hidden) # [B, H * W, heatmap_dim] return output.view(B, H, W, self.heatmap_dim).permute(0, 3, 1, 2) @@ -39,7 +49,7 @@ if __name__ == "__main__": input = torch.randn(1, 9, 13, 13).to(device) cond = torch.randint(0, 1, [1, 9, 13, 13]).to(device) - t = torch.randint(0, 100, [1, 1]).to(device) + t = torch.randint(0, 100, [1]).to(device) # 初始化模型 model = GinkaHeatmapModel(heatmap_dim=9).to(device) diff --git a/ginka/train_heatmap.py b/ginka/train_heatmap.py index 06898d8..81d9601 100644 --- a/ginka/train_heatmap.py +++ b/ginka/train_heatmap.py @@ -49,6 +49,7 @@ T_DIFFUSION = 100 MIN_MASK = 0 MAX_MASK = 0.8 NOISE_SAMPLING_K = [40, 15, 21, 8, 8, 4, 1, 2, 10] +W = 5 # CFG 参数 device = torch.device( "cuda:1" if 
torch.cuda.is_available() @@ -131,11 +132,15 @@ def train(): target_heatmap = batch["target_heatmap"].to(device) B, C, H, W = target_heatmap.shape - t = torch.randint(1, T_DIFFUSION, (B,), device=device) + t = torch.randint(1, T_DIFFUSION, [B], device=device) noise = torch.randn_like(target_heatmap) x_t = diffusion.q_sample(target_heatmap, t, noise) + # CFG 随机概率没有输入条件 + if np.random.rand() < 0.2: + cond_heatmap = torch.zeros_like(cond_heatmap) + pred_noise = model(x_t, cond_heatmap, t) loss = F.mse_loss(pred_noise, noise) @@ -185,8 +190,7 @@ def train(): # 2. 从头完整生成,并使用训练好的 MaskGIT 生成地图 if args.use_maskgit: - fake_heatmap = diffusion.sample(model, cond_heatmap) - map = maskGIT_generate(maskGIT, B, fake_heatmap) + map = full_generate(model, maskGIT, cond_heatmap, diffusion) generated_img = matrix_to_image_cv(map.view(B, H, W)[0].cpu().numpy(), tile_dict) cv2.imwrite(f"result/final_img/{idx}.png", generated_img) @@ -199,8 +203,7 @@ def train(): noise = generate_fractal_noise_2d((16, 16), (4, 4), 1)[0:MAP_H,0:MAP_W] ar[0,c] = nms_sampling(noise, NOISE_SAMPLING_K[c]) - fake_heatmap = diffusion.sample(model, torch.FloatTensor(ar).to(device)) - map = maskGIT_generate(maskGIT, B, fake_heatmap) + map = full_generate(model, maskGIT, torch.FloatTensor(ar).to(device), diffusion) generated_img = matrix_to_image_cv(map.view(1, H, W)[0].cpu().numpy(), tile_dict) cv2.imwrite(f"result/final_img/g-{i}.png", generated_img) @@ -215,6 +218,12 @@ def train(): "model_state": maskGIT.state_dict(), }, f"result/ginka_heatmap.pth") +def full_generate(heatmap, maskGIT, cond_heatmap: torch.Tensor, diffusion: Diffusion): + fake_heatmap_cond = diffusion.sample(heatmap, cond_heatmap) + fake_heatmap_uncond = diffusion.sample(heatmap, torch.zeros_like(cond_heatmap)) + fake_heatmap = fake_heatmap_uncond + W * (fake_heatmap_cond - fake_heatmap_uncond) + return maskGIT_generate(maskGIT, cond_heatmap.shape[0], fake_heatmap) + def maskGIT_generate(maskGIT, B: int, heatmap: torch.Tensor): map =
torch.full((B, MAP_H * MAP_W), MASK_TOKEN).to(device) for i in range(GENERATE_STEP):