diff --git a/ginka/heatmap/model.py b/ginka/heatmap/model.py index 618c1e9..e2ee83e 100644 --- a/ginka/heatmap/model.py +++ b/ginka/heatmap/model.py @@ -29,20 +29,20 @@ class GinkaHeatmapModel(nn.Module): def forward(self, input: torch.Tensor, cond: torch.Tensor, t: torch.Tensor): # input: [B, heatmap_dim, H, W] 噪声 # cond: [B, heatmap_dim, H, W] 点图 - # t: [B, 1] + # t: [B] input = self.input(input, t) # [B, d_model, H, W] cond = self.cond(cond, t) # [B, d_model, H, W] - B, C, H, W = cond.shape - cond_tokens = cond.view(B, C, H * W).permute(0, 2, 1) # [B, H * W, d_model] - scale = torch.sigmoid(cond) - hidden = input * (1 + scale) + cond + B, C, H, W = input.shape + scale = torch.sigmoid(cond) # [B, d_model, H, W] + hidden = input * (1 + scale) + cond # [B, d_model, H, W] hidden = hidden.view(B, C, H * W).permute(0, 2, 1) # [B, H * W, d_model] - hidden = hidden + self.pos_embedding + hidden = hidden + self.pos_embedding # [B, H * W, d_model] hidden = self.transformer(hidden) # [B, H * W, d_model] - attn, _ = self.cross_attn(hidden, cond_tokens, cond_tokens) - hidden = hidden + attn + cond_tokens = cond.view(B, C, H * W).permute(0, 2, 1) # [B, H * W, d_model] + attn, _ = self.cross_attn(hidden, cond_tokens, cond_tokens) # [B, H * W, d_model] + hidden = hidden + attn # [B, H * W, d_model] output = self.output_fc(hidden) # [B, H * W, heatmap_dim] - return output.view(B, H, W, self.heatmap_dim).permute(0, 3, 1, 2) + return output.view(B, H, W, self.heatmap_dim).permute(0, 3, 1, 2) # [B, heatmap_dim, H, W] if __name__ == "__main__": device = torch.device("cpu") diff --git a/ginka/train_heatmap.py b/ginka/train_heatmap.py index 5c0d836..edec34b 100644 --- a/ginka/train_heatmap.py +++ b/ginka/train_heatmap.py @@ -185,7 +185,7 @@ def train(): pred_noise = model(x_t, cond_heatmap, t) - loss = F.l1_loss(pred_noise, noise) + loss = F.mse_loss(pred_noise, noise) val_loss_total += loss.detach() @@ -238,10 +238,11 @@ def get_nms_sampling_count(): def full_generate(heatmap, maskGIT, cond_heatmap: torch.Tensor, diffusion: Diffusion): fake_heatmap_cond = diffusion.sample(heatmap, cond_heatmap) fake_heatmap_uncond = diffusion.sample(heatmap, torch.zeros_like(cond_heatmap)) - fake_heatmap = fake_heatmap_uncond + W * (fake_heatmap_uncond - fake_heatmap_cond) + fake_heatmap = fake_heatmap_uncond + W * (fake_heatmap_uncond - fake_heatmap_cond) # [B, C, H, W] return maskGIT_generate(maskGIT, cond_heatmap.shape[0], fake_heatmap) def maskGIT_generate(maskGIT, B: int, heatmap: torch.Tensor): + # heatmap: [B, C, H, W] map = torch.full((B, MAP_H * MAP_W), MASK_TOKEN).to(device) for i in range(GENERATE_STEP): # 1. 预测 diff --git a/ginka/train_joint.py b/ginka/train_joint.py index a1c9db4..34f01e7 100644 --- a/ginka/train_joint.py +++ b/ginka/train_joint.py @@ -349,6 +349,7 @@ def train(): pred_noise_for_joint = model(x_t, cond_heatmap, t) generated_heatmap = predict_x0(diffusion, x_t, pred_noise_for_joint, t) + print(torch.mean(generated_heatmap), torch.std(generated_heatmap), generated_heatmap.shape) maskgit_loss = maskgit_joint_loss(maskgit, generated_heatmap, target_map) loss = diffusion_loss + CE_WEIGHT * maskgit_loss