diff --git a/ginka/heatmap/diffusion.py b/ginka/heatmap/diffusion.py index fc8e5c8..3705c99 100644 --- a/ginka/heatmap/diffusion.py +++ b/ginka/heatmap/diffusion.py @@ -2,7 +2,7 @@ import math import torch class Diffusion: - def __init__(self, device, T=100, min_beta=0.0001, max_beta=0.02): + def __init__(self, device, T=100, min_beta=0.0001, max_beta=0.01): self.T = T self.device = device @@ -29,7 +29,7 @@ class Diffusion: def sample(self, model, cond: torch.Tensor): x = torch.randn_like(cond).to(self.device) for t in range(self.n_steps - 1, -1, -1): - x = self.sample_backward_step(x, t, model) + x = self.sample_backward_step(x, t, cond, model) return x def sample_backward_step(self, x_t, t, cond, model): diff --git a/ginka/train_heatmap.py b/ginka/train_heatmap.py index 46a213f..5318dc3 100644 --- a/ginka/train_heatmap.py +++ b/ginka/train_heatmap.py @@ -132,6 +132,8 @@ def train(): target_heatmap = batch["target_heatmap"].to(device) B, C, H, W = target_heatmap.shape + optimizer.zero_grad() + t = torch.randint(1, T_DIFFUSION, [B], device=device) noise = torch.randn_like(target_heatmap) @@ -145,7 +147,6 @@ def train(): loss = F.mse_loss(pred_noise, noise) - optimizer.zero_grad() loss.backward() optimizer.step() loss_total += loss.detach()