mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
fix: 报错
This commit is contained in:
parent
5df14a3f3f
commit
cbbe312444
@ -2,7 +2,7 @@ import math
|
||||
import torch
|
||||
|
||||
class Diffusion:
|
||||
def __init__(self, device, T=100, min_beta=0.0001, max_beta=0.02):
|
||||
def __init__(self, device, T=100, min_beta=0.0001, max_beta=0.01):
|
||||
self.T = T
|
||||
self.device = device
|
||||
|
||||
@ -29,7 +29,7 @@ class Diffusion:
|
||||
def sample(self, model, cond: torch.Tensor):
|
||||
x = torch.randn_like(cond).to(self.device)
|
||||
for t in range(self.n_steps - 1, -1, -1):
|
||||
x = self.sample_backward_step(x, t, model)
|
||||
x = self.sample_backward_step(x, t, cond, model)
|
||||
return x
|
||||
|
||||
def sample_backward_step(self, x_t, t, cond, model):
|
||||
|
||||
@ -132,6 +132,8 @@ def train():
|
||||
target_heatmap = batch["target_heatmap"].to(device)
|
||||
B, C, H, W = target_heatmap.shape
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
t = torch.randint(1, T_DIFFUSION, [B], device=device)
|
||||
noise = torch.randn_like(target_heatmap)
|
||||
|
||||
@ -145,7 +147,6 @@ def train():
|
||||
|
||||
loss = F.mse_loss(pred_noise, noise)
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
loss_total += loss.detach()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user