refactor: heatmap 模型采用预测 x_0 而非噪声

Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
unanmed 2026-04-25 17:09:48 +08:00
parent b471bb46eb
commit 1eda704986
3 changed files with 43 additions and 41 deletions

View File

@@ -2,9 +2,10 @@ import math
import torch
class Diffusion:
def __init__(self, device, T=100):
def __init__(self, device, T=100, noise_scale=0.5):
self.T = T
self.device = device
self.noise_scale = noise_scale
# cosine schedule推荐
steps = torch.arange(T + 1, dtype=torch.float32)
@@ -18,33 +19,41 @@ class Diffusion:
def q_sample(self, x0, t, noise):
"""
前向加噪
前向加噪x_t = sqrt(αbar_t) * x0 + sqrt(1-αbar_t) * noise_scale * ε
noise_scale 降低噪声功率使信号不被淹没
"""
return (
self.sqrt_ab[t][:, None, None, None] * x0
+ self.sqrt_one_minus_ab[t][:, None, None, None] * noise
+ self.sqrt_one_minus_ab[t][:, None, None, None] * noise * self.noise_scale
)
def sample(self, model, cond: torch.Tensor, steps=20):
"""
DDIM 风格逆向采样模型预测 x_0
x_{t-1} = sqrt(αbar_{t-1}) * x0_pred
+ sqrt(1-αbar_{t-1}) / sqrt(1-αbar_t) * (x_t - sqrt(αbar_t) * x0_pred)
"""
B = cond.shape[0]
x = torch.randn_like(cond).to(cond.device)
# 初始噪声与前向过程保持一致的噪声功率
x = torch.randn_like(cond).to(cond.device) * self.noise_scale
step_size = self.T // steps
for i in reversed(range(0, self.T, step_size)):
t = torch.full((B,), i, device=cond.device)
pred_noise = model(x, cond, t)
# 模型直接预测 x_0
x0_pred = model(x, cond, t)
alpha = self.alpha_bar[i]
alpha_prev = self.alpha_bar[max(i - step_size, 0)]
x0_pred = (x - torch.sqrt(1 - alpha) * pred_noise) / torch.sqrt(alpha)
# DDIM x0-prediction 更新
direction = (
torch.sqrt(1 - alpha_prev) / torch.sqrt(1 - alpha)
) * (x - torch.sqrt(alpha) * x0_pred)
x = (
torch.sqrt(alpha_prev) * x0_pred
+ torch.sqrt(1 - alpha_prev) * pred_noise
)
x = torch.sqrt(alpha_prev) * x0_pred + direction
return x

View File

@@ -48,6 +48,7 @@ D_MODEL_DIFFUSION = 128
T_DIFFUSION = 100
MIN_MASK = 0
MAX_MASK = 1
NOISE_SCALE = 0.3
W = 5 # CFG 参数
device = torch.device(
@@ -91,7 +92,7 @@ def train():
num_layers=NUM_LAYERS_DIFFUSION
).to(device)
diffusion = Diffusion(device)
diffusion = Diffusion(device, noise_scale=NOISE_SCALE)
dataset = GinkaHeatmapDataset(args.train, min_mask=MIN_MASK, max_mask=MAX_MASK)
dataset_val = GinkaHeatmapDataset(args.validate, min_mask=MIN_MASK, max_mask=MAX_MASK)
@@ -129,7 +130,7 @@ def train():
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
cond_heatmap = batch["cond_heatmap"].to(device)
target_heatmap = batch["target_heatmap"].to(device) * 2 - 1
target_heatmap = batch["target_heatmap"].to(device)
B, C, H, W = target_heatmap.shape
optimizer.zero_grad()
@@ -143,9 +144,10 @@ def train():
if np.random.rand() < 0.2:
cond_heatmap = torch.zeros_like(cond_heatmap)
pred_noise = model(x_t, cond_heatmap, t)
# 模型预测 x_0损失直接对齐热力图
pred_x0 = model(x_t, cond_heatmap, t)
loss = F.mse_loss(pred_noise, noise)
loss = F.mse_loss(pred_x0, target_heatmap)
loss.backward()
optimizer.step()
@@ -175,7 +177,7 @@ def train():
for batch in tqdm(dataloader_val, desc="Validating", leave=False, disable=disable_tqdm):
# 1. 验证集验证
cond_heatmap = batch["cond_heatmap"].to(device)
target_heatmap = batch["target_heatmap"].to(device) * 2 - 1
target_heatmap = batch["target_heatmap"].to(device)
B, C, H, W = target_heatmap.shape
t = torch.randint(1, T_DIFFUSION, [B], device=device)
@@ -183,9 +185,9 @@ def train():
x_t = diffusion.q_sample(target_heatmap, t, noise)
pred_noise = model(x_t, cond_heatmap, t)
pred_x0 = model(x_t, cond_heatmap, t)
loss = F.mse_loss(pred_noise, noise)
loss = F.mse_loss(pred_x0, target_heatmap)
val_loss_total += loss.detach()
@@ -236,8 +238,8 @@ def get_nms_sampling_count():
]
def full_generate(heatmap, maskGIT, cond_heatmap: torch.Tensor, diffusion: Diffusion):
fake_heatmap_cond = (diffusion.sample(heatmap, cond_heatmap) + 1) / 2
fake_heatmap_uncond = (diffusion.sample(heatmap, torch.zeros_like(cond_heatmap)) + 1) / 2
fake_heatmap_cond = diffusion.sample(heatmap, cond_heatmap)
fake_heatmap_uncond = diffusion.sample(heatmap, torch.zeros_like(cond_heatmap))
fake_heatmap = fake_heatmap_uncond + W * (fake_heatmap_uncond - fake_heatmap_cond) # [B, C, H, W]
return maskGIT_generate(maskGIT, cond_heatmap.shape[0], fake_heatmap)

View File

@@ -46,6 +46,7 @@ D_MODEL_DIFFUSION = 128
T_DIFFUSION = 100
MIN_MASK = 0
MAX_MASK = 1
NOISE_SCALE = 0.3
# 验证预览配置
PREVIEW_CFG_WEIGHT = 5 # 预览生成时使用的 CFG 强度
@@ -102,14 +103,6 @@ def freeze_module(module: torch.nn.Module):
parameter.requires_grad = False
def predict_x0(diffusion: Diffusion, x_t: torch.Tensor, pred_noise: torch.Tensor, t: torch.Tensor):
# 根据当前时刻的噪声预测还原 x0 热力图估计。
sqrt_ab = diffusion.sqrt_ab[t][:, None, None, None]
sqrt_one_minus_ab = diffusion.sqrt_one_minus_ab[t][:, None, None, None]
x0 = (x_t - sqrt_one_minus_ab * pred_noise) / sqrt_ab
return x0
def maskgit_joint_loss(maskgit, generated_heatmap: torch.Tensor, target_map: torch.Tensor):
# 用冻结的 MaskGIT 对 Diffusion 生成的热力图施加地图级监督。
batch_size, height, width = target_map.shape
@@ -233,7 +226,7 @@ def validate(model, maskgit, diffusion, dataloader, ce_weight: float, tile_dict)
preview_idx = 0
for batch in tqdm(dataloader, desc="Validating", leave=False, disable=disable_tqdm):
cond_heatmap = batch["cond_heatmap"].to(device)
target_heatmap = batch["target_heatmap"].to(device) * 2 - 1
target_heatmap = batch["target_heatmap"].to(device)
target_map = batch["target_map"].to(device)
batch_size, _, map_height, map_width = target_heatmap.shape
@@ -241,11 +234,10 @@ def validate(model, maskgit, diffusion, dataloader, ce_weight: float, tile_dict)
noise = torch.randn_like(target_heatmap)
x_t = diffusion.q_sample(target_heatmap, t, noise)
pred_noise = model(x_t, cond_heatmap, t)
diffusion_loss = F.mse_loss(pred_noise, noise)
pred_x0 = model(x_t, cond_heatmap, t)
diffusion_loss = F.mse_loss(pred_x0, target_heatmap)
generated_heatmap = (predict_x0(diffusion, x_t, pred_noise, t) + 1) / 2
maskgit_loss = maskgit_joint_loss(maskgit, generated_heatmap, target_map)
maskgit_loss = maskgit_joint_loss(maskgit, pred_x0, target_map)
loss = diffusion_loss + ce_weight * maskgit_loss
total_loss += loss.item()
@@ -297,7 +289,7 @@ def train():
d_model=D_MODEL_DIFFUSION,
num_layers=NUM_LAYERS_DIFFUSION,
).to(device)
diffusion = Diffusion(device, T=T_DIFFUSION)
diffusion = Diffusion(device, T=T_DIFFUSION, noise_scale=NOISE_SCALE)
dataset = GinkaJointDataset(args.train, min_mask=MIN_MASK, max_mask=MAX_MASK)
dataset_val = GinkaJointDataset(args.validate, min_mask=MIN_MASK, max_mask=MAX_MASK)
@@ -325,7 +317,7 @@ def train():
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
cond_heatmap = batch["cond_heatmap"].to(device)
target_heatmap = batch["target_heatmap"].to(device) * 2 - 1
target_heatmap = batch["target_heatmap"].to(device)
target_map = batch["target_map"].to(device)
batch_size = target_heatmap.shape[0]
@@ -341,16 +333,15 @@ def train():
cond_for_diffusion = torch.zeros_like(cond_heatmap)
use_unconditional_branch = True
pred_noise = model(x_t, cond_for_diffusion, t)
diffusion_loss = F.mse_loss(pred_noise, noise)
pred_x0 = model(x_t, cond_for_diffusion, t)
diffusion_loss = F.mse_loss(pred_x0, target_heatmap)
pred_noise_for_joint = pred_noise
# 若使用无条件分支,重新对有条件输入预测以计算联合损失
pred_x0_for_joint = pred_x0
if use_unconditional_branch:
pred_noise_for_joint = model(x_t, cond_heatmap, t)
pred_x0_for_joint = model(x_t, cond_heatmap, t)
generated_heatmap = (predict_x0(diffusion, x_t, pred_noise_for_joint, t) + 1) / 2
print(torch.mean(generated_heatmap), torch.std(generated_heatmap), generated_heatmap.shape)
maskgit_loss = maskgit_joint_loss(maskgit, generated_heatmap, target_map)
maskgit_loss = maskgit_joint_loss(maskgit, pred_x0_for_joint, target_map)
loss = diffusion_loss + CE_WEIGHT * maskgit_loss
loss.backward()