chore: 取消注释

This commit is contained in:
unanmed 2026-03-11 12:30:44 +08:00
parent 954c0109b9
commit 8672c52ff5

View File

@ -112,30 +112,30 @@ def train():
for epoch in tqdm(range(args.epochs), desc="VAE Training", disable=disable_tqdm):
loss_total = torch.Tensor([0]).to(device)
# for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
# target_map = batch["target_map"].to(device)
# cond = batch["val_cond"].to(device)
# B, H, W = target_map.shape
# target_map = target_map.view(B, H * W)
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
target_map = batch["target_map"].to(device)
cond = batch["val_cond"].to(device)
B, H, W = target_map.shape
target_map = target_map.view(B, H * W)
# # 1. 随机采样掩码比例 r (遵循余弦调度效果更好)
# r = torch.rand(B).to(device)
# r = torch.cos(r * math.pi / 2).unsqueeze(1) # 产生更多高掩码比例的样本
# 1. 随机采样掩码比例 r (遵循余弦调度效果更好)
r = torch.rand(B).to(device)
r = torch.cos(r * math.pi / 2).unsqueeze(1) # 产生更多高掩码比例的样本
# # 2. 生成掩码矩阵
# masks = torch.rand(target_map.shape).to(device) < r
# masked_input = target_map.clone()
# masked_input[masks] = MASK_TOKEN # 填充为 [MASK] 标记
# 2. 生成掩码矩阵
masks = torch.rand(target_map.shape).to(device) < r
masked_input = target_map.clone()
masked_input[masks] = MASK_TOKEN # 填充为 [MASK] 标记
# logits = model(masked_input, cond)
logits = model(masked_input, cond)
# loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none', label_smoothing=1)
# loss = (loss * masks).sum() / (masks.sum() + 1e-6)
loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none', label_smoothing=1)
loss = (loss * masks).sum() / (masks.sum() + 1e-6)
# optimizer.zero_grad()
# loss.backward()
# optimizer.step()
# loss_total += loss.detach()
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss_total += loss.detach()
scheduler.step()