mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
64 lines
2.0 KiB
Python
64 lines
2.0 KiB
Python
import math
|
||
import torch
|
||
|
||
class Diffusion:
    """Forward noising and DDIM-style sampling for a model that predicts x_0.

    The cumulative signal level ``alpha_bar`` follows a cosine schedule with
    ``T + 1`` entries (indices ``0..T``), normalised so ``alpha_bar[0] == 1``.
    The noise term of the forward process is additionally multiplied by
    ``noise_scale``.
    """

    # Lower bound for sqrt(1 - alpha_bar) when it is used as a divisor.
    _EPS = 1e-8

    def __init__(self, device, T=100, noise_scale=0.5, s=0.1):
        """Precompute the schedule tables.

        Args:
            device: Device the schedule tensors are moved to.
            T: Number of diffusion steps; the tables hold ``T + 1`` values.
            noise_scale: Multiplier applied to the noise term in
                ``q_sample`` and to the initial noise in ``sample``.
            s: Offset of the cosine schedule.
        """
        self.T = T
        self.device = device
        self.noise_scale = noise_scale

        # Cosine schedule (recommended).
        steps = torch.arange(T + 1, dtype=torch.float32)
        f = torch.cos(((steps / (T + 1)) + s) / (1 + s) * math.pi * 0.5) ** 2
        # Normalise so that step 0 carries the clean signal (alpha_bar == 1).
        alpha_bar = f / f[0]

        self.alpha_bar = alpha_bar.to(device)
        self.sqrt_ab = torch.sqrt(self.alpha_bar)
        self.sqrt_one_minus_ab = torch.sqrt(1 - self.alpha_bar)

    @staticmethod
    def _per_sample(coeff, ref):
        """Reshape per-sample coefficients to broadcast over ref's trailing dims."""
        return coeff.reshape(-1, *([1] * (ref.dim() - 1)))

    def q_sample(self, x0, t, noise):
        """Apply the forward (noising) process at timesteps ``t``.

        x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise_scale * eps

        ``noise_scale`` lowers the noise power so the signal is not drowned
        out.

        Args:
            x0: Clean samples with the batch on the first axis (any rank).
            t: Integer index tensor of shape ``(B,)`` into the schedule.
            noise: Noise tensor broadcastable with ``x0``.

        Returns:
            The noised samples ``x_t``.
        """
        signal_coeff = self._per_sample(self.sqrt_ab[t], x0)
        noise_coeff = self._per_sample(self.sqrt_one_minus_ab[t], x0)
        return signal_coeff * x0 + noise_coeff * noise * self.noise_scale

    def sample(self, model, cond: torch.Tensor, steps=20):
        """Run DDIM-style reverse sampling with an x_0-predicting model.

        x_{t-1} = sqrt(alpha_bar_{t-1}) * x0_pred
                  + sqrt(1 - alpha_bar_{t-1}) / sqrt(1 - alpha_bar_t)
                    * (x_t - sqrt(alpha_bar_t) * x0_pred)

        Args:
            model: Callable invoked as ``model(x, cond, t)`` that returns the
                predicted x_0.
            cond: Conditioning tensor; the sample is created with the same
                shape, dtype and device.
            steps: Requested number of sampling steps. The stride through the
                schedule is ``T // steps``, clamped to at least 1.

        Returns:
            The tensor obtained after the last reverse step.

        Raises:
            ValueError: If ``steps`` is smaller than 1.
        """
        if steps < 1:
            raise ValueError(f"steps must be >= 1, got {steps}")

        B = cond.shape[0]
        # Start from noise with the same power as in the forward process.
        x = torch.randn_like(cond) * self.noise_scale

        # Clamp the stride: T // steps is 0 when steps > T, which range()
        # rejects.
        step_size = max(self.T // steps, 1)

        for i in reversed(range(0, self.T, step_size)):
            t = torch.full((B,), i, device=cond.device)

            # The model predicts x_0 directly.
            x0_pred = model(x, cond, t)

            prev = max(i - step_size, 0)

            # At i == 0 both sqrt(1 - alpha_bar) terms are exactly 0; without
            # the clamp the ratio is 0/0 = NaN and the whole result is NaN.
            # With it the ratio is 0 and the final state equals x0_pred.
            ratio = (
                self.sqrt_one_minus_ab[prev]
                / self.sqrt_one_minus_ab[i].clamp_min(self._EPS)
            )

            # DDIM x0-prediction update.
            x = self.sqrt_ab[prev] * x0_pred + ratio * (x - self.sqrt_ab[i] * x0_pred)

        return x
|
||
|
||
if __name__ == '__main__':
    # Smoke check: build the default schedule on CPU and print both
    # precomputed coefficient tables, noise coefficient first.
    schedule = Diffusion("cpu")
    for table in (schedule.sqrt_one_minus_ab, schedule.sqrt_ab):
        print(table)
|