fix: 报错

This commit is contained in:
unanmed 2026-04-08 13:04:27 +08:00
parent 5df14a3f3f
commit cbbe312444
2 changed files with 4 additions and 3 deletions

View File

@ -2,7 +2,7 @@ import math
import torch import torch
class Diffusion: class Diffusion:
def __init__(self, device, T=100, min_beta=0.0001, max_beta=0.02): def __init__(self, device, T=100, min_beta=0.0001, max_beta=0.01):
self.T = T self.T = T
self.device = device self.device = device
@ -29,7 +29,7 @@ class Diffusion:
def sample(self, model, cond: torch.Tensor): def sample(self, model, cond: torch.Tensor):
x = torch.randn_like(cond).to(self.device) x = torch.randn_like(cond).to(self.device)
for t in range(self.n_steps - 1, -1, -1): for t in range(self.n_steps - 1, -1, -1):
x = self.sample_backward_step(x, t, model) x = self.sample_backward_step(x, t, cond, model)
return x return x
def sample_backward_step(self, x_t, t, cond, model): def sample_backward_step(self, x_t, t, cond, model):

View File

@ -132,6 +132,8 @@ def train():
target_heatmap = batch["target_heatmap"].to(device) target_heatmap = batch["target_heatmap"].to(device)
B, C, H, W = target_heatmap.shape B, C, H, W = target_heatmap.shape
optimizer.zero_grad()
t = torch.randint(1, T_DIFFUSION, [B], device=device) t = torch.randint(1, T_DIFFUSION, [B], device=device)
noise = torch.randn_like(target_heatmap) noise = torch.randn_like(target_heatmap)
@ -145,7 +147,6 @@ def train():
loss = F.mse_loss(pred_noise, noise) loss = F.mse_loss(pred_noise, noise)
optimizer.zero_grad()
loss.backward() loss.backward()
optimizer.step() optimizer.step()
loss_total += loss.detach() loss_total += loss.detach()