fix: mask

This commit is contained in:
unanmed 2026-03-11 12:29:33 +08:00
parent fa72e18c1c
commit 954c0109b9
2 changed files with 28 additions and 27 deletions

View File

@@ -51,6 +51,7 @@ VAL_BATCH_DIVIDER = 16
NUM_CLASSES = 16
MASK_TOKEN = 15
GENERATE_STEP = 8
MAP_SIZE = 13 * 13
device = torch.device(
"cuda:1" if torch.cuda.is_available()
@@ -111,30 +112,30 @@ def train():
for epoch in tqdm(range(args.epochs), desc="VAE Training", disable=disable_tqdm):
loss_total = torch.Tensor([0]).to(device)
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
target_map = batch["target_map"].to(device)
cond = batch["val_cond"].to(device)
B, H, W = target_map.shape
target_map = target_map.view(B, H * W)
# for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
# target_map = batch["target_map"].to(device)
# cond = batch["val_cond"].to(device)
# B, H, W = target_map.shape
# target_map = target_map.view(B, H * W)
# 1. 随机采样掩码比例 r (遵循余弦调度效果更好)
r = torch.rand(B).to(device)
r = torch.cos(r * math.pi / 2).unsqueeze(1) # 产生更多高掩码比例的样本
# # 1. 随机采样掩码比例 r (遵循余弦调度效果更好)
# r = torch.rand(B).to(device)
# r = torch.cos(r * math.pi / 2).unsqueeze(1) # 产生更多高掩码比例的样本
# 2. 生成掩码矩阵
masks = torch.rand(target_map.shape).to(device) < r
masked_input = target_map.clone()
masked_input[masks] = MASK_TOKEN # 填充为 [MASK] 标记
# # 2. 生成掩码矩阵
# masks = torch.rand(target_map.shape).to(device) < r
# masked_input = target_map.clone()
# masked_input[masks] = MASK_TOKEN # 填充为 [MASK] 标记
logits = model(masked_input, cond)
# logits = model(masked_input, cond)
loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none')
loss = (loss * masks).sum() / (masks.sum() + 1e-6)
# loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none', label_smoothing=1)
# loss = (loss * masks).sum() / (masks.sum() + 1e-6)
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss_total += loss.detach()
# optimizer.zero_grad()
# loss.backward()
# optimizer.step()
# loss_total += loss.detach()
scheduler.step()
@@ -178,7 +179,7 @@ def train():
logits = model(masked_input, cond)
loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none')
loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none', label_smoothing=0.1)
loss = (loss * masks.view(-1)).sum() / (masks.sum() + 1e-6)
val_loss_total += loss.detach()
@@ -193,28 +194,28 @@ def train():
idx += 1
# 2. 从头完整生成
map = torch.full((1, 169), MASK_TOKEN).to(device)
for _ in range(GENERATE_STEP):
map = torch.full((1, MAP_SIZE), MASK_TOKEN).to(device)
for i in range(GENERATE_STEP):
# 1. 预测
logits = model(map, cond)
logits = model(map, cond) # [1, H * W, num_classes]
probs = F.softmax(logits, dim=-1)
# 2. 采样(为了多样性,这里可以使用概率采样而不是取最大值)
dist = torch.distributions.Categorical(probs)
sampled_tiles = dist.sample() # (1, 169)
sampled_tiles = dist.sample() # [1, H * W]
# 3. 计算置信度 (模型对采样结果的信心程度)
confidences = torch.gather(probs, -1, sampled_tiles.unsqueeze(-1)).squeeze(-1)
# 4. 决定本轮要固定多少个格子 (上凸函数逻辑)
ratio = math.cos((GENERATE_STEP) * math.pi / 2)
num_to_mask = int(ratio * 169)
ratio = math.cos(((i + 1) / GENERATE_STEP) * math.pi / 2)
num_to_mask = math.floor(ratio * MAP_SIZE)
# 5. 更新画布:保留置信度最高的部分,其余位置设回 MASK
# 注意:这里逻辑上通常是保留当前步预测中置信度最高的,并结合已有的非 mask 部分
if num_to_mask > 0:
_, mask_indices = torch.topk(confidences, k=num_to_mask, largest=False)
sampled_tiles.scatter_(1, mask_indices, MASK_TOKEN)
sampled_tiles = sampled_tiles.scatter(1, mask_indices, MASK_TOKEN)
map = sampled_tiles
if (map == MASK_TOKEN).sum() == 0:

BIN
tiles2/15.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 414 B