From cbbe31244485c60fcdfab017bbc23a614703d23d Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Wed, 8 Apr 2026 13:04:27 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E6=8A=A5=E9=94=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/heatmap/diffusion.py | 4 ++-- ginka/train_heatmap.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/ginka/heatmap/diffusion.py b/ginka/heatmap/diffusion.py index fc8e5c8..3705c99 100644 --- a/ginka/heatmap/diffusion.py +++ b/ginka/heatmap/diffusion.py @@ -2,7 +2,7 @@ import math import torch class Diffusion: - def __init__(self, device, T=100, min_beta=0.0001, max_beta=0.02): + def __init__(self, device, T=100, min_beta=0.0001, max_beta=0.01): self.T = T self.device = device @@ -29,7 +29,7 @@ class Diffusion: def sample(self, model, cond: torch.Tensor): x = torch.randn_like(cond).to(self.device) for t in range(self.n_steps - 1, -1, -1): - x = self.sample_backward_step(x, t, model) + x = self.sample_backward_step(x, t, cond, model) return x def sample_backward_step(self, x_t, t, cond, model): diff --git a/ginka/train_heatmap.py b/ginka/train_heatmap.py index 46a213f..5318dc3 100644 --- a/ginka/train_heatmap.py +++ b/ginka/train_heatmap.py @@ -132,6 +132,8 @@ def train(): target_heatmap = batch["target_heatmap"].to(device) B, C, H, W = target_heatmap.shape + optimizer.zero_grad() + t = torch.randint(1, T_DIFFUSION, [B], device=device) noise = torch.randn_like(target_heatmap) @@ -145,7 +147,6 @@ def train(): loss = F.mse_loss(pred_noise, noise) - optimizer.zero_grad() loss.backward() optimizer.step() loss_total += loss.detach()