chore: 添加 epoch 之间的检查

This commit is contained in:
unanmed 2026-04-25 15:28:29 +08:00
parent 3be014f3ad
commit 966d007721
3 changed files with 13 additions and 11 deletions

View File

@ -29,20 +29,20 @@ class GinkaHeatmapModel(nn.Module):
def forward(self, input: torch.Tensor, cond: torch.Tensor, t: torch.Tensor):
# input: [B, heatmap_dim, H, W] 噪声
# cond: [B, heatmap_dim, H, W] 点图
# t: [B, 1]
# t: [B]
input = self.input(input, t) # [B, d_model, H, W]
cond = self.cond(cond, t) # [B, d_model, H, W]
B, C, H, W = cond.shape
cond_tokens = cond.view(B, C, H * W).permute(0, 2, 1) # [B, H * W, d_model]
scale = torch.sigmoid(cond)
hidden = input * (1 + scale) + cond
B, C, H, W = input.shape
scale = torch.sigmoid(cond) # [B, d_model, H, W]
hidden = input * (1 + scale) + cond # [B, d_model, H, W]
hidden = hidden.view(B, C, H * W).permute(0, 2, 1) # [B, H * W, d_model]
hidden = hidden + self.pos_embedding
hidden = hidden + self.pos_embedding # [B, H * W, d_model]
hidden = self.transformer(hidden) # [B, H * W, d_model]
attn, _ = self.cross_attn(hidden, cond_tokens, cond_tokens)
hidden = hidden + attn
cond_tokens = cond.view(B, C, H * W).permute(0, 2, 1) # [B, H * W, d_model]
attn, _ = self.cross_attn(hidden, cond_tokens, cond_tokens) # [B, H * W, d_model]
hidden = hidden + attn # [B, H * W, d_model]
output = self.output_fc(hidden) # [B, H * W, heatmap_dim]
return output.view(B, H, W, self.heatmap_dim).permute(0, 3, 1, 2)
return output.view(B, H, W, self.heatmap_dim).permute(0, 3, 1, 2) # [B, heatmap_dim, H, W]
if __name__ == "__main__":
device = torch.device("cpu")

View File

@ -185,7 +185,7 @@ def train():
pred_noise = model(x_t, cond_heatmap, t)
loss = F.l1_loss(pred_noise, noise)
loss = F.mse_loss(pred_noise, noise)
val_loss_total += loss.detach()
@ -238,10 +238,11 @@ def get_nms_sampling_count():
def full_generate(heatmap, maskGIT, cond_heatmap: torch.Tensor, diffusion: Diffusion):
fake_heatmap_cond = diffusion.sample(heatmap, cond_heatmap)
fake_heatmap_uncond = diffusion.sample(heatmap, torch.zeros_like(cond_heatmap))
fake_heatmap = fake_heatmap_uncond + W * (fake_heatmap_uncond - fake_heatmap_cond)
fake_heatmap = fake_heatmap_uncond + W * (fake_heatmap_uncond - fake_heatmap_cond) # [B, C, H, W]
return maskGIT_generate(maskGIT, cond_heatmap.shape[0], fake_heatmap)
def maskGIT_generate(maskGIT, B: int, heatmap: torch.Tensor):
# heatmap: [B, C, H, W]
map = torch.full((B, MAP_H * MAP_W), MASK_TOKEN).to(device)
for i in range(GENERATE_STEP):
# 1. 预测

View File

@ -349,6 +349,7 @@ def train():
pred_noise_for_joint = model(x_t, cond_heatmap, t)
generated_heatmap = predict_x0(diffusion, x_t, pred_noise_for_joint, t)
print(torch.mean(generated_heatmap), torch.std(generated_heatmap), generated_heatmap.shape)
maskgit_loss = maskgit_joint_loss(maskgit, generated_heatmap, target_map)
loss = diffusion_loss + CE_WEIGHT * maskgit_loss