mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-17 06:51:11 +08:00
chore: 换回第一版 diffusion 策略
This commit is contained in:
parent
1237d45d95
commit
54164b9f22
@ -2,56 +2,52 @@ import math
|
||||
import torch
|
||||
|
||||
class Diffusion:
|
||||
def __init__(self, device, T=100, min_beta=0.0001, max_beta=0.01):
|
||||
def __init__(self, device, T=100):
|
||||
self.T = T
|
||||
self.device = device
|
||||
|
||||
betas = torch.linspace(min_beta, max_beta, T).to(device)
|
||||
alphas = 1 - betas
|
||||
alpha_bars = torch.empty_like(alphas)
|
||||
product = 1
|
||||
for i, alpha in enumerate(alphas):
|
||||
product *= alpha
|
||||
alpha_bars[i] = product
|
||||
self.betas = betas
|
||||
self.n_steps = T
|
||||
self.alphas = alphas
|
||||
self.alpha_bars = alpha_bars
|
||||
# cosine schedule(推荐)
|
||||
steps = torch.arange(T + 1, dtype=torch.float32)
|
||||
s = 0.1
|
||||
f = torch.cos(((steps / T) + s) / (1 + s) * math.pi * 0.5) ** 2
|
||||
alpha_bar = f / f[0]
|
||||
|
||||
self.alpha_bar = alpha_bar.to(device)
|
||||
self.sqrt_ab = torch.sqrt(self.alpha_bar)
|
||||
self.sqrt_one_minus_ab = torch.sqrt(1 - self.alpha_bar)
|
||||
|
||||
def q_sample(self, x0, t, noise):
|
||||
"""
|
||||
前向加噪
|
||||
"""
|
||||
alpha_bar = self.alpha_bars[t].reshape(-1, 1, 1, 1)
|
||||
res = noise * torch.sqrt(1 - alpha_bar) + torch.sqrt(alpha_bar) * x0
|
||||
return res
|
||||
return (
|
||||
self.sqrt_ab[t][:, None, None, None] * x0
|
||||
+ self.sqrt_one_minus_ab[t][:, None, None, None] * noise
|
||||
)
|
||||
|
||||
def sample(self, model, cond: torch.Tensor):
|
||||
x = torch.randn_like(cond).to(self.device)
|
||||
for t in range(self.n_steps - 1, -1, -1):
|
||||
x = self.sample_backward_step(x, t, cond, model)
|
||||
def sample(self, model, cond: torch.Tensor, steps=20):
|
||||
B = cond.shape[0]
|
||||
x = torch.randn_like(cond).to(cond.device)
|
||||
|
||||
step_size = self.T // steps
|
||||
|
||||
for i in reversed(range(0, self.T, step_size)):
|
||||
t = torch.full((B,), i, device=cond.device)
|
||||
|
||||
pred_noise = model(x, cond, t)
|
||||
|
||||
alpha = self.alpha_bar[i]
|
||||
alpha_prev = self.alpha_bar[max(i - step_size, 0)]
|
||||
|
||||
x0_pred = (x - torch.sqrt(1 - alpha) * pred_noise) / torch.sqrt(alpha)
|
||||
|
||||
x = (
|
||||
torch.sqrt(alpha_prev) * x0_pred
|
||||
+ torch.sqrt(1 - alpha_prev) * pred_noise
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
def sample_backward_step(self, x_t, t, cond, model):
|
||||
B = x_t.shape[0]
|
||||
t_tensor = torch.tensor([t] * B, dtype=torch.long).to(self.device)
|
||||
eps = model(x_t, cond, t_tensor)
|
||||
|
||||
if t == 0:
|
||||
noise = 0
|
||||
else:
|
||||
var = (1 - self.alpha_bars[t - 1]) / (1 - self.alpha_bars[t]) * self.betas[t]
|
||||
noise = torch.randn_like(x_t)
|
||||
noise *= torch.sqrt(var)
|
||||
|
||||
mean = (x_t -
|
||||
(1 - self.alphas[t]) / torch.sqrt(1 - self.alpha_bars[t]) *
|
||||
eps) / torch.sqrt(self.alphas[t])
|
||||
x_t = mean + noise
|
||||
|
||||
return x_t
|
||||
|
||||
if __name__ == '__main__':
|
||||
diff = Diffusion("cpu")
|
||||
print(diff.alphas)
|
||||
print(diff.alpha_bars)
|
||||
print(diff.sqrt_one_minus_ab)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user