import math

import torch


class Diffusion:
    """Forward-noising and sampling helper built on a cosine noise schedule.

    The constructor precomputes three lookup tables of length ``T + 1``
    (indices ``0 .. T``), all stored on ``device``:

    * ``alpha_bar`` -- cumulative signal level, normalised so that
      ``alpha_bar[0] == 1``;
    * ``sqrt_ab`` -- ``sqrt(alpha_bar)``;
    * ``sqrt_one_minus_ab`` -- ``sqrt(1 - alpha_bar)``.
    """

    def __init__(self, device, T: int = 100):
        """Build the schedule tables.

        Args:
            device: Device the schedule tensors are moved to.
            T: Number of diffusion steps; the tables hold ``T + 1`` entries.
        """
        self.T = T
        self.device = device

        # Cosine schedule (recommended).
        steps = torch.arange(T + 1, dtype=torch.float32)
        s = 0.008  # small offset inside the cosine argument
        f = torch.cos(((steps / T) + s) / (1 + s) * math.pi * 0.5) ** 2
        # Normalise by the first entry so that alpha_bar[0] is exactly 1.
        alpha_bar = f / f[0]

        self.alpha_bar = alpha_bar.to(device)
        self.sqrt_ab = torch.sqrt(self.alpha_bar)
        self.sqrt_one_minus_ab = torch.sqrt(1 - self.alpha_bar)

    def q_sample(self, x0: torch.Tensor, t, noise: torch.Tensor) -> torch.Tensor:
        """Forward noising: mix ``x0`` with ``noise`` at timestep(s) ``t``.

        Computes ``sqrt_ab[t] * x0 + sqrt_one_minus_ab[t] * noise``, with the
        per-sample coefficients broadcast over every non-batch dimension of
        ``x0``.

        Args:
            x0: Clean input whose first dimension is the batch.
            t: Index (or 1-D tensor of per-sample indices) into the schedule
                tables.
            noise: Noise tensor broadcastable against ``x0``.

        Returns:
            The noised tensor.
        """
        # One leading batch axis followed by a singleton axis per remaining
        # dimension of x0, so the coefficients broadcast for any input rank
        # (for 4-D input this is the same as [:, None, None, None]).
        coeff_shape = (-1,) + (1,) * (x0.dim() - 1)
        signal_scale = self.sqrt_ab[t].reshape(coeff_shape)
        noise_scale = self.sqrt_one_minus_ab[t].reshape(coeff_shape)
        return signal_scale * x0 + noise_scale * noise

    def sample(self, model, cond: torch.Tensor, steps: int = 20) -> torch.Tensor:
        """Generate a sample by iteratively denoising Gaussian noise.

        Starts from ``torch.randn_like(cond)`` and walks the schedule
        downwards with stride ``T // steps`` (at least 1). Each step uses the
        model output to estimate ``x0`` and then re-noises that estimate to
        the previous schedule index with the same predicted noise; no fresh
        noise is injected inside the loop (the deterministic DDIM-style
        update).

        The model is evaluated once per visited index, i.e.
        ``ceil(T / stride)`` times, which can differ from ``steps`` when
        ``steps`` does not divide ``T``.

        NOTE(review): autograd is not disabled here; wrap the call in
        ``torch.no_grad()`` when gradients through sampling are not needed.

        Args:
            model: Callable invoked as ``model(x, cond, t)``; its result is
                used as the predicted noise -- presumably the same shape as
                ``x``, confirm against the model definition.
            cond: Conditioning tensor; its shape, dtype and device define the
                generated sample, and ``cond.shape[0]`` is the batch size.
            steps: Requested number of sampling steps (must be >= 1).

        Returns:
            The final tensor, same shape as ``cond``.

        Raises:
            ValueError: If ``steps`` is smaller than 1.
        """
        if steps < 1:
            raise ValueError(f"steps must be a positive integer, got {steps}")

        batch_size = cond.shape[0]
        # randn_like already matches cond's device and dtype.
        x = torch.randn_like(cond)

        # Clamp to 1 so that steps > T does not produce a zero stride
        # (range() rejects a step of 0).
        step_size = max(self.T // steps, 1)

        for i in reversed(range(0, self.T, step_size)):
            t = torch.full(
                (batch_size,), i, dtype=torch.long, device=cond.device
            )
            pred_noise = model(x, cond, t)

            prev = max(i - step_size, 0)

            # Invert the forward process to estimate the clean sample.
            x0_pred = (
                x - self.sqrt_one_minus_ab[i] * pred_noise
            ) / self.sqrt_ab[i]

            # Move the estimate to the previous noise level. At i == 0 both
            # indices are 0 (alpha_bar == 1), so x is left unchanged.
            x = (
                self.sqrt_ab[prev] * x0_pred
                + self.sqrt_one_minus_ab[prev] * pred_noise
            )

        return x