mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
chore: 添加 epoch 之间的检查
This commit is contained in:
parent
3be014f3ad
commit
966d007721
@ -29,20 +29,20 @@ class GinkaHeatmapModel(nn.Module):
|
||||
def forward(self, input: torch.Tensor, cond: torch.Tensor, t: torch.Tensor):
|
||||
# input: [B, heatmap_dim, H, W] 噪声
|
||||
# cond: [B, heatmap_dim, H, W] 点图
|
||||
# t: [B, 1]
|
||||
# t: [B]
|
||||
input = self.input(input, t) # [B, d_model, H, W]
|
||||
cond = self.cond(cond, t) # [B, d_model, H, W]
|
||||
B, C, H, W = cond.shape
|
||||
cond_tokens = cond.view(B, C, H * W).permute(0, 2, 1) # [B, H * W, d_model]
|
||||
scale = torch.sigmoid(cond)
|
||||
hidden = input * (1 + scale) + cond
|
||||
B, C, H, W = input.shape
|
||||
scale = torch.sigmoid(cond) # [B, d_model, H, W]
|
||||
hidden = input * (1 + scale) + cond # [B, d_model, H, W]
|
||||
hidden = hidden.view(B, C, H * W).permute(0, 2, 1) # [B, H * W, d_model]
|
||||
hidden = hidden + self.pos_embedding
|
||||
hidden = hidden + self.pos_embedding # [B, H * W, d_model]
|
||||
hidden = self.transformer(hidden) # [B, H * W, d_model]
|
||||
attn, _ = self.cross_attn(hidden, cond_tokens, cond_tokens)
|
||||
hidden = hidden + attn
|
||||
cond_tokens = cond.view(B, C, H * W).permute(0, 2, 1) # [B, H * W, d_model]
|
||||
attn, _ = self.cross_attn(hidden, cond_tokens, cond_tokens) # [B, H * W, d_model]
|
||||
hidden = hidden + attn # [B, H * W, d_model]
|
||||
output = self.output_fc(hidden) # [B, H * W, heatmap_dim]
|
||||
return output.view(B, H, W, self.heatmap_dim).permute(0, 3, 1, 2)
|
||||
return output.view(B, H, W, self.heatmap_dim).permute(0, 3, 1, 2) # [B, heatmap_dim, H, W]
|
||||
|
||||
if __name__ == "__main__":
|
||||
device = torch.device("cpu")
|
||||
|
||||
@ -185,7 +185,7 @@ def train():
|
||||
|
||||
pred_noise = model(x_t, cond_heatmap, t)
|
||||
|
||||
loss = F.l1_loss(pred_noise, noise)
|
||||
loss = F.mse_loss(pred_noise, noise)
|
||||
|
||||
val_loss_total += loss.detach()
|
||||
|
||||
@ -238,10 +238,11 @@ def get_nms_sampling_count():
|
||||
def full_generate(heatmap, maskGIT, cond_heatmap: torch.Tensor, diffusion: Diffusion):
|
||||
fake_heatmap_cond = diffusion.sample(heatmap, cond_heatmap)
|
||||
fake_heatmap_uncond = diffusion.sample(heatmap, torch.zeros_like(cond_heatmap))
|
||||
fake_heatmap = fake_heatmap_uncond + W * (fake_heatmap_uncond - fake_heatmap_cond)
|
||||
fake_heatmap = fake_heatmap_uncond + W * (fake_heatmap_uncond - fake_heatmap_cond) # [B, C, H, W]
|
||||
return maskGIT_generate(maskGIT, cond_heatmap.shape[0], fake_heatmap)
|
||||
|
||||
def maskGIT_generate(maskGIT, B: int, heatmap: torch.Tensor):
|
||||
# heatmap: [B, C, H, W]
|
||||
map = torch.full((B, MAP_H * MAP_W), MASK_TOKEN).to(device)
|
||||
for i in range(GENERATE_STEP):
|
||||
# 1. 预测
|
||||
|
||||
@ -349,6 +349,7 @@ def train():
|
||||
pred_noise_for_joint = model(x_t, cond_heatmap, t)
|
||||
|
||||
generated_heatmap = predict_x0(diffusion, x_t, pred_noise_for_joint, t)
|
||||
print(torch.mean(generated_heatmap), torch.std(generated_heatmap), generated_heatmap.shape)
|
||||
maskgit_loss = maskgit_joint_loss(maskgit, generated_heatmap, target_map)
|
||||
|
||||
loss = diffusion_loss + CE_WEIGHT * maskgit_loss
|
||||
|
||||
Loading…
Reference in New Issue
Block a user