ginka-generator/ginka/heatmap/diffusion.py
unanmed 1eda704986 refactor: heatmap 模型采用预测 x_0 而非噪声
Co-authored-by: Copilot <copilot@github.com>
2026-04-25 17:09:48 +08:00

64 lines
2.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import math
import torch
class Diffusion:
    """Gaussian diffusion helper with a cosine ``alpha_bar`` schedule.

    The model used with this class predicts ``x_0`` directly (not the noise).
    Gaussian noise is multiplied by ``noise_scale`` both in the forward process
    and in the initial sample of the reverse process.

    Args:
        device: Device on which the schedule tensors are stored.
        T: Number of diffusion timesteps; the schedule has ``T + 1`` entries,
            indexed ``0..T``.
        noise_scale: Multiplier applied to the Gaussian noise.
        cosine_s: Offset ``s`` of the cosine schedule (default ``0.1``).
    """

    def __init__(self, device, T=100, noise_scale=0.5, cosine_s=0.1):
        self.T = T
        self.device = device
        self.noise_scale = noise_scale
        # Cosine schedule: alpha_bar(t) = f(t) / f(0), with
        # f(t) = cos^2(((t / (T + 1)) + s) / (1 + s) * pi / 2).
        # Dividing by f[0] makes alpha_bar[0] exactly 1 (no noise at t = 0).
        steps = torch.arange(T + 1, dtype=torch.float32)
        f = torch.cos(
            ((steps / (T + 1)) + cosine_s) / (1 + cosine_s) * math.pi * 0.5
        ) ** 2
        alpha_bar = f / f[0]
        self.alpha_bar = alpha_bar.to(device)
        self.sqrt_ab = torch.sqrt(self.alpha_bar)
        self.sqrt_one_minus_ab = torch.sqrt(1 - self.alpha_bar)

    def q_sample(self, x0, t, noise):
        """Sample x_t from the forward process.

        x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise_scale * noise

        ``noise_scale`` lowers the noise power so the signal is not swamped.

        Args:
            x0: Clean samples; the first dimension is the batch.
            t: Integer timesteps in ``[0, T]``, one per batch element.
            noise: Tensor with the same shape as ``x0`` (presumably standard
                Gaussian noise; confirm against callers).

        Returns:
            The noised tensor x_t, with the same shape as ``x0``.
        """
        # Reshape the per-sample coefficients to (B, 1, ..., 1) so they
        # broadcast over every non-batch dimension of x0, whatever its rank.
        shape = (-1,) + (1,) * (x0.dim() - 1)
        signal_coef = self.sqrt_ab[t].reshape(shape)
        noise_coef = self.sqrt_one_minus_ab[t].reshape(shape)
        return signal_coef * x0 + noise_coef * noise * self.noise_scale

    def sample(self, model, cond: torch.Tensor, steps=20):
        """Run deterministic DDIM-style reverse sampling with an x_0-predicting model.

        Each non-final step applies
            x_{t-1} = sqrt(alpha_bar_{t-1}) * x0_pred
                      + sqrt(1 - alpha_bar_{t-1}) / sqrt(1 - alpha_bar_t)
                        * (x_t - sqrt(alpha_bar_t) * x0_pred)
        and the final step returns ``x0_pred`` directly.

        NOTE(review): no ``torch.no_grad()`` is applied here. Callers presumably
        wrap sampling in ``no_grad``/``inference_mode``; otherwise autograd
        keeps the graph of every model call. Confirm.

        Args:
            model: Callable ``model(x, cond, t) -> x0_pred``. ``t`` is a
                ``(B,)`` long tensor holding the current timestep.
            cond: Conditioning tensor. The initial noise is drawn with its
                shape, dtype and device.
            steps: Requested number of reverse steps. The timestep stride is
                ``max(T // steps, 1)``, so the actual number of model calls is
                ``ceil(T / stride)``.

        Returns:
            The generated sample, with the same shape as ``cond``.

        Raises:
            ValueError: If ``steps < 1``.
        """
        if steps < 1:
            raise ValueError(f"steps must be >= 1, got {steps}")
        B = cond.shape[0]
        # Initial noise has the same power as the forward process noise.
        x = torch.randn_like(cond) * self.noise_scale
        # Clamp the stride so steps > T does not yield range(..., step=0).
        step_size = max(self.T // steps, 1)
        for i in reversed(range(0, self.T, step_size)):
            t = torch.full((B,), i, device=cond.device, dtype=torch.long)
            # The model predicts x_0 directly.
            x0_pred = model(x, cond, t)
            prev = i - step_size
            if prev < 0:
                # Final step: alpha_bar_prev would be alpha_bar[0] == 1, so the
                # update reduces to x0_pred. Evaluating the formula here (i == 0)
                # would compute 0 / 0 = NaN, because 1 - alpha_bar[0] is exactly 0.
                x = x0_pred
                continue
            alpha = self.alpha_bar[i]
            alpha_prev = self.alpha_bar[prev]
            # DDIM (eta = 0) update for an x_0-predicting model.
            direction = (
                torch.sqrt(1 - alpha_prev) / torch.sqrt(1 - alpha)
            ) * (x - torch.sqrt(alpha) * x0_pred)
            x = torch.sqrt(alpha_prev) * x0_pred + direction
        return x
if __name__ == '__main__':
    # Quick sanity check: dump the precomputed schedule coefficients
    # (noise coefficients first, then signal coefficients).
    demo = Diffusion("cpu")
    for table in (demo.sqrt_one_minus_ab, demo.sqrt_ab):
        print(table)