mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
feat: 调整 Diffusion 模型
This commit is contained in:
parent
c00a7dc5c1
commit
3e898dc5ba
@ -40,7 +40,7 @@ class HeatmapCond(nn.Module):
|
|||||||
|
|
||||||
def forward(self, heatmap: torch.Tensor, t: torch.Tensor):
|
def forward(self, heatmap: torch.Tensor, t: torch.Tensor):
|
||||||
# heatmap: [B, C, H, W]
|
# heatmap: [B, C, H, W]
|
||||||
# t: [B, 1]
|
# t: [B]
|
||||||
t_embed = self.time_embedding(t)
|
t_embed = self.time_embedding(t)
|
||||||
x = self.conv1(heatmap) + self.fc1(t_embed).unsqueeze(1).unsqueeze(1).permute(0, 3, 1, 2)
|
x = self.conv1(heatmap) + self.fc1(t_embed).unsqueeze(1).unsqueeze(1).permute(0, 3, 1, 2)
|
||||||
x = self.conv2(x) + self.fc2(t_embed).unsqueeze(1).unsqueeze(1).permute(0, 3, 1, 2)
|
x = self.conv2(x) + self.fc2(t_embed).unsqueeze(1).unsqueeze(1).permute(0, 3, 1, 2)
|
||||||
|
|||||||
@ -2,49 +2,56 @@ import math
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
class Diffusion:
    """Gaussian diffusion process with a linear beta schedule.

    Precomputes the per-step noise schedule and exposes the forward
    noising step (``q_sample``) and ancestral reverse sampling (``sample``).

    Attributes:
        T / n_steps: number of diffusion steps (both names are kept because
            callers use either one).
        betas: linearly spaced noise variances, shape ``[T]``.
        alphas: ``1 - betas``, shape ``[T]``.
        alpha_bars: cumulative product of ``alphas``, shape ``[T]``.
    """

    def __init__(self, device, T=100, min_beta=0.0001, max_beta=0.02):
        """Build the schedule tensors on ``device``.

        Args:
            device: device the schedule tensors live on.
            T: number of diffusion steps.
            min_beta: beta of the first step.
            max_beta: beta of the last step.
        """
        self.T = T
        self.n_steps = T
        self.device = device

        betas = torch.linspace(min_beta, max_beta, T).to(device)
        alphas = 1 - betas

        self.betas = betas
        self.alphas = alphas
        # Running product of alphas; replaces the element-by-element Python loop.
        self.alpha_bars = torch.cumprod(alphas, dim=0)

    def q_sample(self, x0, t, noise):
        """Forward noising: draw ``x_t`` from ``x0`` at timestep(s) ``t``.

        Args:
            x0: clean sample; the schedule term is reshaped to
                ``[-1, 1, 1, 1]``, so a 4-D batch is expected.
            t: integer timestep index (tensor) into the schedule, one per sample.
            noise: noise tensor broadcastable with ``x0``.

        Returns:
            ``sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise``.
        """
        alpha_bar = self.alpha_bars[t].reshape(-1, 1, 1, 1)
        return torch.sqrt(alpha_bar) * x0 + torch.sqrt(1 - alpha_bar) * noise

    def sample(self, model, cond: torch.Tensor):
        """Run the full reverse chain from pure noise, conditioned on ``cond``.

        Gradients are not disabled here; wrap the call in ``torch.no_grad()``
        when sampling for inference.

        Args:
            model: callable ``model(x_t, cond, t)`` returning the predicted noise.
            cond: conditioning tensor; the generated sample has its shape.

        Returns:
            The sample after stepping from ``n_steps - 1`` down to ``0``.
        """
        x = torch.randn_like(cond).to(self.device)
        for t in range(self.n_steps - 1, -1, -1):
            # `cond` must be forwarded: omitting it shifted `model` into the
            # `cond` slot and raised TypeError for the missing argument.
            x = self.sample_backward_step(x, t, cond, model)
        return x

    def sample_backward_step(self, x_t, t, cond, model):
        """Perform one reverse step ``x_t -> x_{t-1}``.

        Args:
            x_t: current noisy sample; the first dimension is the batch.
            t: integer timestep shared by the whole batch.
            cond: conditioning tensor passed through to ``model``.
            model: callable ``model(x_t, cond, t_tensor)`` returning predicted noise.

        Returns:
            The posterior mean, plus scaled Gaussian noise when ``t > 0``.
        """
        batch = x_t.shape[0]
        t_tensor = torch.full((batch,), t, dtype=torch.long, device=self.device)
        eps = model(x_t, cond, t_tensor)

        mean = (
            x_t - (1 - self.alphas[t]) / torch.sqrt(1 - self.alpha_bars[t]) * eps
        ) / torch.sqrt(self.alphas[t])

        # The last step is deterministic: no noise is added at t == 0.
        if t == 0:
            return mean

        var = (1 - self.alpha_bars[t - 1]) / (1 - self.alpha_bars[t]) * self.betas[t]
        return mean + torch.sqrt(var) * torch.randn_like(x_t)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Smoke check: build the default schedule on CPU and dump its tables.
    schedule = Diffusion("cpu")
    for table in (schedule.alphas, schedule.alpha_bars):
        print(table)
|
||||||
|
|||||||
@ -16,8 +16,14 @@ class GinkaHeatmapModel(nn.Module):
|
|||||||
self.cond = HeatmapCond(T, embed_dim=embed_dim, heatmap_dim=heatmap_dim, output_dim=d_model)
|
self.cond = HeatmapCond(T, embed_dim=embed_dim, heatmap_dim=heatmap_dim, output_dim=d_model)
|
||||||
self.input = HeatmapCond(T, embed_dim=embed_dim, heatmap_dim=heatmap_dim, output_dim=d_model)
|
self.input = HeatmapCond(T, embed_dim=embed_dim, heatmap_dim=heatmap_dim, output_dim=d_model)
|
||||||
self.transformer = MaskGIT(d_model=d_model, dim_ff=dim_ff, nhead=nhead, num_layers=num_layers)
|
self.transformer = MaskGIT(d_model=d_model, dim_ff=dim_ff, nhead=nhead, num_layers=num_layers)
|
||||||
|
self.cross_attn = nn.MultiheadAttention(d_model, num_heads=nhead, batch_first=True)
|
||||||
self.output_fc = nn.Sequential(
|
self.output_fc = nn.Sequential(
|
||||||
nn.Linear(d_model, heatmap_dim)
|
nn.Linear(d_model, d_model // 2),
|
||||||
|
nn.LayerNorm(d_model // 2),
|
||||||
|
nn.Dropout(0.3),
|
||||||
|
nn.GELU(),
|
||||||
|
|
||||||
|
nn.Linear(d_model // 2, heatmap_dim)
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, input: torch.Tensor, cond: torch.Tensor, t: torch.Tensor):
|
def forward(self, input: torch.Tensor, cond: torch.Tensor, t: torch.Tensor):
|
||||||
@ -26,11 +32,15 @@ class GinkaHeatmapModel(nn.Module):
|
|||||||
# t: [B, 1]
|
# t: [B]
|
||||||
input = self.input(input, t) # [B, d_model, H, W]
|
input = self.input(input, t) # [B, d_model, H, W]
|
||||||
cond = self.cond(cond, t) # [B, d_model, H, W]
|
cond = self.cond(cond, t) # [B, d_model, H, W]
|
||||||
hidden = input + cond
|
B, C, H, W = cond.shape
|
||||||
B, C, H, W = hidden.shape
|
cond_tokens = cond.view(B, C, H * W).permute(0, 2, 1) # [B, H * W, d_model]
|
||||||
|
scale = torch.sigmoid(cond)
|
||||||
|
hidden = input * (1 + scale) + cond
|
||||||
hidden = hidden.view(B, C, H * W).permute(0, 2, 1) # [B, H * W, d_model]
|
hidden = hidden.view(B, C, H * W).permute(0, 2, 1) # [B, H * W, d_model]
|
||||||
hidden = hidden + self.pos_embedding
|
hidden = hidden + self.pos_embedding
|
||||||
hidden = self.transformer(hidden) # [B, H * W, d_model]
|
hidden = self.transformer(hidden) # [B, H * W, d_model]
|
||||||
|
attn, _ = self.cross_attn(hidden, cond_tokens, cond_tokens)
|
||||||
|
hidden = hidden + attn
|
||||||
output = self.output_fc(hidden) # [B, H * W, heatmap_dim]
|
output = self.output_fc(hidden) # [B, H * W, heatmap_dim]
|
||||||
return output.view(B, H, W, self.heatmap_dim).permute(0, 3, 1, 2)
|
return output.view(B, H, W, self.heatmap_dim).permute(0, 3, 1, 2)
|
||||||
|
|
||||||
@ -39,7 +49,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
input = torch.randn(1, 9, 13, 13).to(device)
|
input = torch.randn(1, 9, 13, 13).to(device)
|
||||||
cond = torch.randint(0, 1, [1, 9, 13, 13]).to(device)
|
cond = torch.randint(0, 1, [1, 9, 13, 13]).to(device)
|
||||||
t = torch.randint(0, 100, [1, 1]).to(device)
|
t = torch.randint(0, 100, [1]).to(device)
|
||||||
|
|
||||||
# 初始化模型
|
# 初始化模型
|
||||||
model = GinkaHeatmapModel(heatmap_dim=9).to(device)
|
model = GinkaHeatmapModel(heatmap_dim=9).to(device)
|
||||||
|
|||||||
@ -49,6 +49,7 @@ T_DIFFUSION = 100
|
|||||||
MIN_MASK = 0
|
MIN_MASK = 0
|
||||||
MAX_MASK = 0.8
|
MAX_MASK = 0.8
|
||||||
NOISE_SAMPLING_K = [40, 15, 21, 8, 8, 4, 1, 2, 10]
|
NOISE_SAMPLING_K = [40, 15, 21, 8, 8, 4, 1, 2, 10]
|
||||||
|
W = 5 # CFG 参数
|
||||||
|
|
||||||
device = torch.device(
|
device = torch.device(
|
||||||
"cuda:1" if torch.cuda.is_available()
|
"cuda:1" if torch.cuda.is_available()
|
||||||
@ -131,11 +132,15 @@ def train():
|
|||||||
target_heatmap = batch["target_heatmap"].to(device)
|
target_heatmap = batch["target_heatmap"].to(device)
|
||||||
B, C, H, W = target_heatmap.shape
|
B, C, H, W = target_heatmap.shape
|
||||||
|
|
||||||
t = torch.randint(1, T_DIFFUSION, (B,), device=device)
|
t = torch.randint(1, T_DIFFUSION, [B], device=device)
|
||||||
noise = torch.randn_like(target_heatmap)
|
noise = torch.randn_like(target_heatmap)
|
||||||
|
|
||||||
x_t = diffusion.q_sample(target_heatmap, t, noise)
|
x_t = diffusion.q_sample(target_heatmap, t, noise)
|
||||||
|
|
||||||
|
# CFG 随机概率没有输入条件
|
||||||
|
if np.random.rand() < 0.2:
|
||||||
|
cond_heatmap = torch.zeros_like(cond_heatmap)
|
||||||
|
|
||||||
pred_noise = model(x_t, cond_heatmap, t)
|
pred_noise = model(x_t, cond_heatmap, t)
|
||||||
|
|
||||||
loss = F.mse_loss(pred_noise, noise)
|
loss = F.mse_loss(pred_noise, noise)
|
||||||
@ -185,8 +190,7 @@ def train():
|
|||||||
|
|
||||||
# 2. 从头完整生成,并使用训练好的 MaskGIT 生成地图
|
# 2. 从头完整生成,并使用训练好的 MaskGIT 生成地图
|
||||||
if args.use_maskgit:
|
if args.use_maskgit:
|
||||||
fake_heatmap = diffusion.sample(model, cond_heatmap)
|
map = full_generate(model, maskGIT, cond_heatmap, diffusion)
|
||||||
map = maskGIT_generate(maskGIT, B, fake_heatmap)
|
|
||||||
|
|
||||||
generated_img = matrix_to_image_cv(map.view(B, H, W)[0].cpu().numpy(), tile_dict)
|
generated_img = matrix_to_image_cv(map.view(B, H, W)[0].cpu().numpy(), tile_dict)
|
||||||
cv2.imwrite(f"result/final_img/{idx}.png", generated_img)
|
cv2.imwrite(f"result/final_img/{idx}.png", generated_img)
|
||||||
@ -199,8 +203,7 @@ def train():
|
|||||||
noise = generate_fractal_noise_2d((16, 16), (4, 4), 1)[0:MAP_H,0:MAP_W]
|
noise = generate_fractal_noise_2d((16, 16), (4, 4), 1)[0:MAP_H,0:MAP_W]
|
||||||
ar[0,c] = nms_sampling(noise, NOISE_SAMPLING_K[c])
|
ar[0,c] = nms_sampling(noise, NOISE_SAMPLING_K[c])
|
||||||
|
|
||||||
fake_heatmap = diffusion.sample(model, torch.FloatTensor(ar).to(device))
|
map = full_generate(model, maskGIT, torch.FloatTensor(ar).to(device), diffusion)
|
||||||
map = maskGIT_generate(maskGIT, B, fake_heatmap)
|
|
||||||
generated_img = matrix_to_image_cv(map.view(1, H, W)[0].cpu().numpy(), tile_dict)
|
generated_img = matrix_to_image_cv(map.view(1, H, W)[0].cpu().numpy(), tile_dict)
|
||||||
cv2.imwrite(f"result/final_img/g-{i}.png", generated_img)
|
cv2.imwrite(f"result/final_img/g-{i}.png", generated_img)
|
||||||
|
|
||||||
@ -215,6 +218,12 @@ def train():
|
|||||||
"model_state": maskGIT.state_dict(),
|
"model_state": maskGIT.state_dict(),
|
||||||
}, f"result/ginka_heatmap.pth")
|
}, f"result/ginka_heatmap.pth")
|
||||||
|
|
||||||
|
def full_generate(heatmap, maskGIT, cond_heatmap: torch.Tensor, diffusion: Diffusion, guidance_scale=None):
    """Generate a map: sample a guided heatmap, then decode it with MaskGIT.

    Args:
        heatmap: noise-prediction model passed to ``diffusion.sample``.
        maskGIT: map generator passed to ``maskGIT_generate``.
        cond_heatmap: conditioning heatmap; its first dimension is the batch size.
        diffusion: sampler providing ``sample(model, cond)``.
        guidance_scale: classifier-free guidance weight. ``None`` falls back to
            the module-level constant ``W``.

    Returns:
        Whatever ``maskGIT_generate`` returns for the guided heatmap.
    """
    scale = W if guidance_scale is None else guidance_scale

    fake_heatmap_cond = diffusion.sample(heatmap, cond_heatmap)
    # An all-zero condition stands in for "no condition".
    fake_heatmap_uncond = diffusion.sample(heatmap, torch.zeros_like(cond_heatmap))

    # Guidance must extrapolate from the unconditional result TOWARDS the
    # conditional one (cond - uncond). The reversed difference pushed the
    # output away from the requested condition.
    # NOTE(review): the two samples come from separate sampling runs; standard
    # classifier-free guidance mixes the noise predictions at every step
    # instead — consider moving the mix into the sampler.
    fake_heatmap = fake_heatmap_uncond + scale * (fake_heatmap_cond - fake_heatmap_uncond)
    return maskGIT_generate(maskGIT, cond_heatmap.shape[0], fake_heatmap)
|
||||||
|
|
||||||
def maskGIT_generate(maskGIT, B: int, heatmap: torch.Tensor):
|
def maskGIT_generate(maskGIT, B: int, heatmap: torch.Tensor):
|
||||||
map = torch.full((B, MAP_H * MAP_W), MASK_TOKEN).to(device)
|
map = torch.full((B, MAP_H * MAP_W), MASK_TOKEN).to(device)
|
||||||
for i in range(GENERATE_STEP):
|
for i in range(GENERATE_STEP):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user