mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
chore: 添加 epoch 之间的检查
This commit is contained in:
parent
3be014f3ad
commit
966d007721
@ -29,20 +29,20 @@ class GinkaHeatmapModel(nn.Module):
|
|||||||
def forward(self, input: torch.Tensor, cond: torch.Tensor, t: torch.Tensor):
|
def forward(self, input: torch.Tensor, cond: torch.Tensor, t: torch.Tensor):
|
||||||
# input: [B, heatmap_dim, H, W] 噪声
|
# input: [B, heatmap_dim, H, W] 噪声
|
||||||
# cond: [B, heatmap_dim, H, W] 点图
|
# cond: [B, heatmap_dim, H, W] 点图
|
||||||
# t: [B, 1]
|
# t: [B]
|
||||||
input = self.input(input, t) # [B, d_model, H, W]
|
input = self.input(input, t) # [B, d_model, H, W]
|
||||||
cond = self.cond(cond, t) # [B, d_model, H, W]
|
cond = self.cond(cond, t) # [B, d_model, H, W]
|
||||||
B, C, H, W = cond.shape
|
B, C, H, W = input.shape
|
||||||
cond_tokens = cond.view(B, C, H * W).permute(0, 2, 1) # [B, H * W, d_model]
|
scale = torch.sigmoid(cond) # [B, d_model, H, W]
|
||||||
scale = torch.sigmoid(cond)
|
hidden = input * (1 + scale) + cond # [B, d_model, H, W]
|
||||||
hidden = input * (1 + scale) + cond
|
|
||||||
hidden = hidden.view(B, C, H * W).permute(0, 2, 1) # [B, H * W, d_model]
|
hidden = hidden.view(B, C, H * W).permute(0, 2, 1) # [B, H * W, d_model]
|
||||||
hidden = hidden + self.pos_embedding
|
hidden = hidden + self.pos_embedding # [B, H * W, d_model]
|
||||||
hidden = self.transformer(hidden) # [B, H * W, d_model]
|
hidden = self.transformer(hidden) # [B, H * W, d_model]
|
||||||
attn, _ = self.cross_attn(hidden, cond_tokens, cond_tokens)
|
cond_tokens = cond.view(B, C, H * W).permute(0, 2, 1) # [B, H * W, d_model]
|
||||||
hidden = hidden + attn
|
attn, _ = self.cross_attn(hidden, cond_tokens, cond_tokens) # [B, H * W, d_model]
|
||||||
|
hidden = hidden + attn # [B, H * W, d_model]
|
||||||
output = self.output_fc(hidden) # [B, H * W, heatmap_dim]
|
output = self.output_fc(hidden) # [B, H * W, heatmap_dim]
|
||||||
return output.view(B, H, W, self.heatmap_dim).permute(0, 3, 1, 2)
|
return output.view(B, H, W, self.heatmap_dim).permute(0, 3, 1, 2) # [B, heatmap_dim, H, W]
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
|
|||||||
@ -185,7 +185,7 @@ def train():
|
|||||||
|
|
||||||
pred_noise = model(x_t, cond_heatmap, t)
|
pred_noise = model(x_t, cond_heatmap, t)
|
||||||
|
|
||||||
loss = F.l1_loss(pred_noise, noise)
|
loss = F.mse_loss(pred_noise, noise)
|
||||||
|
|
||||||
val_loss_total += loss.detach()
|
val_loss_total += loss.detach()
|
||||||
|
|
||||||
@ -238,10 +238,11 @@ def get_nms_sampling_count():
|
|||||||
def full_generate(heatmap, maskGIT, cond_heatmap: torch.Tensor, diffusion: Diffusion):
|
def full_generate(heatmap, maskGIT, cond_heatmap: torch.Tensor, diffusion: Diffusion):
|
||||||
fake_heatmap_cond = diffusion.sample(heatmap, cond_heatmap)
|
fake_heatmap_cond = diffusion.sample(heatmap, cond_heatmap)
|
||||||
fake_heatmap_uncond = diffusion.sample(heatmap, torch.zeros_like(cond_heatmap))
|
fake_heatmap_uncond = diffusion.sample(heatmap, torch.zeros_like(cond_heatmap))
|
||||||
fake_heatmap = fake_heatmap_uncond + W * (fake_heatmap_uncond - fake_heatmap_cond)
|
fake_heatmap = fake_heatmap_uncond + W * (fake_heatmap_uncond - fake_heatmap_cond) # [B, C, H, W]
|
||||||
return maskGIT_generate(maskGIT, cond_heatmap.shape[0], fake_heatmap)
|
return maskGIT_generate(maskGIT, cond_heatmap.shape[0], fake_heatmap)
|
||||||
|
|
||||||
def maskGIT_generate(maskGIT, B: int, heatmap: torch.Tensor):
|
def maskGIT_generate(maskGIT, B: int, heatmap: torch.Tensor):
|
||||||
|
# heatmap: [B, C, H, W]
|
||||||
map = torch.full((B, MAP_H * MAP_W), MASK_TOKEN).to(device)
|
map = torch.full((B, MAP_H * MAP_W), MASK_TOKEN).to(device)
|
||||||
for i in range(GENERATE_STEP):
|
for i in range(GENERATE_STEP):
|
||||||
# 1. 预测
|
# 1. 预测
|
||||||
|
|||||||
@ -349,6 +349,7 @@ def train():
|
|||||||
pred_noise_for_joint = model(x_t, cond_heatmap, t)
|
pred_noise_for_joint = model(x_t, cond_heatmap, t)
|
||||||
|
|
||||||
generated_heatmap = predict_x0(diffusion, x_t, pred_noise_for_joint, t)
|
generated_heatmap = predict_x0(diffusion, x_t, pred_noise_for_joint, t)
|
||||||
|
print(torch.mean(generated_heatmap), torch.std(generated_heatmap), generated_heatmap.shape)
|
||||||
maskgit_loss = maskgit_joint_loss(maskgit, generated_heatmap, target_map)
|
maskgit_loss = maskgit_joint_loss(maskgit, generated_heatmap, target_map)
|
||||||
|
|
||||||
loss = diffusion_loss + CE_WEIGHT * maskgit_loss
|
loss = diffusion_loss + CE_WEIGHT * maskgit_loss
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user