From 8672c52ff53e16858e3ddfeeadb5b2818aa5e88f Mon Sep 17 00:00:00 2001
From: unanmed <1319491857@qq.com>
Date: Wed, 11 Mar 2026 12:30:44 +0800
Subject: [PATCH] chore: uncomment

---
 ginka/train_maskGIT.py | 38 +++++++++++++++++++-------------------
 1 file changed, 19 insertions(+), 19 deletions(-)

diff --git a/ginka/train_maskGIT.py b/ginka/train_maskGIT.py
index 2f3049d..e3b2473 100644
--- a/ginka/train_maskGIT.py
+++ b/ginka/train_maskGIT.py
@@ -112,30 +112,30 @@ def train():
     for epoch in tqdm(range(args.epochs), desc="VAE Training", disable=disable_tqdm):
         loss_total = torch.Tensor([0]).to(device)
-        # for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
-        #     target_map = batch["target_map"].to(device)
-        #     cond = batch["val_cond"].to(device)
-        #     B, H, W = target_map.shape
-        #     target_map = target_map.view(B, H * W)
+        for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
+            target_map = batch["target_map"].to(device)
+            cond = batch["val_cond"].to(device)
+            B, H, W = target_map.shape
+            target_map = target_map.view(B, H * W)
 
-        #     # 1. Randomly sample the mask ratio r (following a cosine schedule works better)
-        #     r = torch.rand(B).to(device)
-        #     r = torch.cos(r * math.pi / 2).unsqueeze(1) # produces more samples with high mask ratios
+            # 1. Randomly sample the mask ratio r (following a cosine schedule works better)
+            r = torch.rand(B).to(device)
+            r = torch.cos(r * math.pi / 2).unsqueeze(1) # produces more samples with high mask ratios
 
-        #     # 2. Generate the mask matrix
-        #     masks = torch.rand(target_map.shape).to(device) < r
-        #     masked_input = target_map.clone()
-        #     masked_input[masks] = MASK_TOKEN # fill with the [MASK] token
+            # 2. Generate the mask matrix
+            masks = torch.rand(target_map.shape).to(device) < r
+            masked_input = target_map.clone()
+            masked_input[masks] = MASK_TOKEN # fill with the [MASK] token
 
-        #     logits = model(masked_input, cond)
+            logits = model(masked_input, cond)
 
-        #     loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none', label_smoothing=0.1)
-        #     loss = (loss * masks).sum() / (masks.sum() + 1e-6)
+            loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none', label_smoothing=0.1)
+            loss = (loss * masks).sum() / (masks.sum() + 1e-6)
 
-        #     optimizer.zero_grad()
-        #     loss.backward()
-        #     optimizer.step()
-        #     loss_total += loss.detach()
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+            loss_total += loss.detach()
 
         scheduler.step()
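
Note on step 1 of the restored loop: for r drawn uniformly from [0, 1), the transformed ratio cos(r * pi / 2) still lies in (0, 1] but is skewed high. Its mean is 2/pi (about 0.637) and its density, (2/pi)/sqrt(1 - y^2), grows without bound as the ratio approaches 1, so training emphasizes heavily masked inputs, which is the regime that iterative MaskGIT decoding starts from. A quick standalone check (plain PyTorch, independent of the repo):

import math
import torch

# Mask-ratio schedule from the patch: r ~ U[0, 1) -> cos(r * pi / 2).
r = torch.rand(100_000)
ratio = torch.cos(r * math.pi / 2)

print(f"mean mask ratio: {ratio.mean():.3f}")  # about 0.637; a uniform ratio would give 0.500
print(f"share masked above 80%: {(ratio > 0.8).float().mean():.3f}")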
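
For anyone who wants to run the restored training step in isolation, here is a minimal self-contained sketch. The toy model, vocabulary size, map size, and condition width below are placeholders invented for illustration, and MASK_TOKEN is given an assumed value; only the mask sampling and masked cross-entropy follow the loop in ginka/train_maskGIT.py.

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# Placeholder constants; the real values live elsewhere in the repo.
VOCAB_SIZE = 32          # assumed tile vocabulary size
MASK_TOKEN = VOCAB_SIZE  # assumed id of the [MASK] token
COND_DIM = 16            # assumed width of the val_cond vector

class ToyMaskGIT(nn.Module):
    """Stand-in for the real model: embeds tokens, adds the condition,
    and predicts per-position logits over the tile vocabulary."""
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(VOCAB_SIZE + 1, 64)  # +1 slot for [MASK]
        self.cond_proj = nn.Linear(COND_DIM, 64)
        self.head = nn.Linear(64, VOCAB_SIZE)

    def forward(self, tokens, cond):
        h = self.embed(tokens) + self.cond_proj(cond).unsqueeze(1)
        return self.head(h)  # (B, H*W, VOCAB_SIZE)

device = "cpu"
model = ToyMaskGIT().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# One fake batch: a 13x13 map of tile ids plus a condition vector.
B, H, W = 4, 13, 13
target_map = torch.randint(0, VOCAB_SIZE, (B, H, W), device=device).view(B, H * W)
cond = torch.randn(B, COND_DIM, device=device)

# 1. Sample one mask ratio per sample; the cosine biases toward high ratios.
r = torch.cos(torch.rand(B, device=device) * math.pi / 2).unsqueeze(1)

# 2. Mask roughly that fraction of positions and write the [MASK] id.
masks = torch.rand(target_map.shape, device=device) < r
masked_input = target_map.clone()
masked_input[masks] = MASK_TOKEN

logits = model(masked_input, cond)

# Cross-entropy per position, then averaged over masked positions only.
loss = F.cross_entropy(logits.permute(0, 2, 1), target_map,
                       reduction='none', label_smoothing=0.1)
loss = (loss * masks).sum() / (masks.sum() + 1e-6)

optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"mask ratio ~{masks.float().mean():.2f}, loss {loss.item():.3f}")

Because the loss is normalized by masks.sum() rather than the full sequence length, gradients come only from masked positions, matching the MaskGIT objective of predicting hidden tokens.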