import math

import torch


class Diffusion:
    """Cosine-schedule diffusion: forward noising plus deterministic sampling.

    The schedule tables hold ``T + 1`` entries (indices ``0..T``):

    * ``alpha_bar`` - cumulative signal rate, normalised so ``alpha_bar[0] == 1``
    * ``sqrt_ab`` - ``sqrt(alpha_bar)``
    * ``sqrt_one_minus_ab`` - ``sqrt(1 - alpha_bar)``
    """

    def __init__(self, device, T=100):
        """Precompute the cosine schedule tables on ``device``.

        Args:
            device: device the schedule tables are moved to.
            T: number of diffusion steps; must be at least 1.

        Raises:
            ValueError: if ``T`` is smaller than 1 (the schedule divides by T).
        """
        if T < 1:
            raise ValueError(f"T must be >= 1, got {T}")

        self.T = T
        self.device = device

        # Cosine schedule (recommended).
        steps = torch.arange(T + 1, dtype=torch.float32)
        s = 0.1
        f = torch.cos(((steps / T) + s) / (1 + s) * math.pi * 0.5) ** 2
        # Normalise by the first entry so that alpha_bar[0] is exactly 1.
        alpha_bar = f / f[0]

        self.alpha_bar = alpha_bar.to(device)
        self.sqrt_ab = torch.sqrt(self.alpha_bar)
        self.sqrt_one_minus_ab = torch.sqrt(1 - self.alpha_bar)

    def q_sample(self, x0, t, noise):
        """Forward noising: mix ``x0`` with ``noise`` at timestep(s) ``t``.

        Args:
            x0: clean input; the coefficients are reshaped to ``(-1, 1, 1, 1)``
                and broadcast against it.
            t: timestep index or indices into the schedule tables; either a
                Python int / 0-d tensor (shared by the whole batch) or a 1-D
                integer tensor with one entry per sample.
            noise: noise tensor broadcastable with ``x0``.

        Returns:
            ``sqrt_ab[t] * x0 + sqrt_one_minus_ab[t] * noise``.
        """
        # as_tensor + reshape accepts scalar timesteps as well as 1-D batches;
        # plain ``[:, None, None, None]`` indexing fails on a 0-d lookup.
        t = torch.as_tensor(t, device=self.sqrt_ab.device)
        signal_scale = self.sqrt_ab[t].reshape(-1, 1, 1, 1)
        noise_scale = self.sqrt_one_minus_ab[t].reshape(-1, 1, 1, 1)
        return signal_scale * x0 + noise_scale * noise

    def sample(self, model, cond: torch.Tensor, steps=20):
        """Generate a sample with a deterministic strided reverse process.

        Starts from Gaussian noise shaped like ``cond`` and visits the
        timesteps ``reversed(range(0, T, step_size))`` where
        ``step_size = max(T // steps, 1)``. When ``steps`` does not divide
        ``T`` the number of model calls can differ from ``steps``.

        Gradients are not disabled here; wrap the call in ``torch.no_grad()``
        when sampling for inference.

        Args:
            model: callable invoked as ``model(x, cond, t)`` that returns the
                predicted noise, with ``t`` a ``(B,)`` long tensor.
            cond: conditioning tensor; its shape and device define the sample.
            steps: requested number of sampling steps; must be at least 1.

        Returns:
            The tensor produced after the last update, on ``cond.device``.

        Raises:
            ValueError: if ``steps`` is smaller than 1.
        """
        if steps < 1:
            raise ValueError(f"steps must be >= 1, got {steps}")

        B = cond.shape[0]
        x = torch.randn_like(cond)

        # Asking for more steps than T would give a zero stride (and a
        # ValueError from range), so fall back to visiting every timestep.
        step_size = max(self.T // steps, 1)

        # Look the schedule up on the sampling device so self.device and
        # cond.device are allowed to differ.
        alpha_bar = self.alpha_bar.to(cond.device)

        for i in reversed(range(0, self.T, step_size)):
            t = torch.full((B,), i, dtype=torch.long, device=cond.device)

            pred_noise = model(x, cond, t)

            alpha = alpha_bar[i]
            alpha_prev = alpha_bar[max(i - step_size, 0)]

            # Estimate the clean sample, then re-noise it to the previous
            # timestep using the same predicted noise (no fresh randomness).
            x0_pred = (x - torch.sqrt(1 - alpha) * pred_noise) / torch.sqrt(alpha)

            x = (
                torch.sqrt(alpha_prev) * x0_pred
                + torch.sqrt(1 - alpha_prev) * pred_noise
            )

        return x


if __name__ == '__main__':
    diff = Diffusion("cpu")
    print(diff.sqrt_one_minus_ab)