feat: 采样时随机抛弃

This commit is contained in:
unanmed 2026-05-15 00:02:19 +08:00
parent 3df7d59575
commit dc3062bcee

View File

@@ -180,11 +180,21 @@ def random_struct(device: torch.device) -> torch.Tensor:
def maskgit_sample(
model: torch.nn.Module, inp: torch.Tensor, z: torch.Tensor,
struct: torch.Tensor, steps: int, keep_fixed: bool = True
struct: torch.Tensor, steps: int, target_tiles: list[int] | None = None,
keep_fixed: bool = True
) -> np.ndarray:
# keep_fixed=True:锁定输入中已有的非掩码位,使上一阶段结构保持不变
# keep_fixed=False:所有位置均可被模型自由重估,适合探索更多样的生成结果
# target_tiles: 本阶段负责生成的图块 ID 列表;None 表示接受所有类别(stage1)
# keep_fixed=True:锁定输入中已有的非掩码/非空地位,使上一阶段结构保持不变
# keep_fixed=False:结构位保留,但每步结束后空地重新标为 MASK(探索模式)
#
# 有 target_tiles 时的核心逻辑:
# 每步只从预测中选出置信度最高的若干 target_tile 候选揭开,
# 其余已有结构(墙/门等非空地非掩码)原样保留,
# 空地与掩码保持为 MASK等待后续步骤继续填充。
current = inp.clone()
has_target = target_tiles is not None
if has_target:
target_tensor = torch.tensor(target_tiles, dtype=torch.long, device=inp.device)
# 迭代去掩码:每步根据置信度分数重新决定掩码位置
for step in range(steps):
@@ -200,27 +210,62 @@ def maskgit_sample(
ratio = math.cos(((step + 1) / steps) * math.pi / 2)
num_to_mask = math.floor(ratio * MAP_SIZE)
if keep_fixed:
# 输入中已有的非掩码位(来自上一阶段)保持不变
fixed_mask = (current[0] != MASK_TOKEN)
sampled[0, fixed_mask] = current[0, fixed_mask]
confidences[0, fixed_mask] = 1.0
if num_to_mask > 0:
# 将置信度最低的位重新掩码,留待下一步重新预测
_, mask_indices = torch.topk(confidences[0], k=num_to_mask, largest=False)
sampled[0].scatter_(0, mask_indices, MASK_TOKEN)
current = sampled
if not has_target:
# stage1:无 target 约束,仅锁定 fixed 位(若 keep_fixed)
if keep_fixed:
fixed_mask = (current[0] != MASK_TOKEN)
sampled[0, fixed_mask] = current[0, fixed_mask]
confidences[0, fixed_mask] = 1.0
if num_to_mask > 0:
_, mask_indices = torch.topk(confidences[0], k=num_to_mask, largest=False)
sampled[0].scatter_(0, mask_indices, MASK_TOKEN)
current = sampled
else:
# 有 target_tiles基于当前 current 构建下一状态
# 结构位current 中非空地、非掩码的位置(来自上一阶段,始终保留)
struct_mask = (current[0] != MASK_TOKEN) & (current[0] != 0)
# 候选位sampled 为目标图块且不覆盖结构位
candidate_mask = torch.isin(sampled[0], target_tensor) & ~struct_mask
# 对候选位按置信度排序,选出置信度最高的若干位揭开
cand_count = candidate_mask.sum().item()
reveal_count = max(0, int(cand_count) - num_to_mask)
next_state = current[0].clone()
if reveal_count > 0 and cand_count > 0:
cand_indices = candidate_mask.nonzero(as_tuple=True)[0]
cand_conf = confidences[0][cand_indices]
top_k = min(reveal_count, cand_conf.size(0))
_, top_idx = torch.topk(cand_conf, k=top_k, largest=True)
reveal_indices = cand_indices[top_idx]
next_state[reveal_indices] = sampled[0][reveal_indices]
# 后处理:进度未超 75% 时,随机将新揭开位的 20%-40% 再次掩码,
# 抑制目标图块过密生成;后期不再压制,确保最终能全部揭开
if step / steps <= 0.75:
suppress_ratio = random.uniform(0.2, 0.4)
suppress_k = max(1, int(reveal_indices.size(0) * suppress_ratio))
suppress_perm = torch.randperm(
reveal_indices.size(0), device=inp.device
)[:suppress_k]
next_state[reveal_indices[suppress_perm]] = MASK_TOKEN
# 结构位原样保留,其余未揭开的置为 MASK
non_struct_non_revealed = (next_state == current[0]) & ~struct_mask
next_state[non_struct_non_revealed & (next_state != MASK_TOKEN)] = MASK_TOKEN
# free 模式下,空地也重新标为 MASK允许下一步继续填充
if not keep_fixed:
next_state[next_state == 0] = MASK_TOKEN
current = next_state.unsqueeze(0)
if (current[0] == MASK_TOKEN).sum() == 0:
break
# 兜底:若仍有残余掩码位(理论上不应发生),用 argmax 确定性填充
# 兜底:若仍有残余掩码位,按模式填充
still_masked = (current[0] == MASK_TOKEN)
if still_masked.any():
logits = model(current, z, struct)
current[0, still_masked] = torch.argmax(logits[0, still_masked], dim=-1)
if has_target:
# 目标模式下,未被填充的位置视为空地(不属于本阶段负责的图块)
current[0, still_masked] = 0
else:
logits = model(current, z, struct)
current[0, still_masked] = torch.argmax(logits[0, still_masked], dim=-1)
return current[0].cpu().numpy().reshape(MAP_H, MAP_W)
@@ -241,15 +286,19 @@ def full_generate_random_z(
inp2 = torch.tensor(pred1_np.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE)
inp2[inp2 == 0] = MASK_TOKEN # 空地位交由 stage2 填充
# stage2:在骨架上生成 door/monster/entrance,非零结果覆盖合并
pred2_np = maskgit_sample(mg2, inp2, z, struct, GENERATE_STEP, keep_fixed=keep_fixed[1])
# stage2:在骨架上生成 door(2)/monster(4)/entrance(5),非零结果覆盖合并
pred2_np = maskgit_sample(
mg2, inp2, z, struct, GENERATE_STEP, target_tiles=[2, 4, 5], keep_fixed=keep_fixed[1]
)
merged12 = pred1_np.copy()
merged12[pred2_np != 0] = pred2_np[pred2_np != 0]
inp3 = torch.tensor(merged12.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE)
inp3[inp3 == 0] = MASK_TOKEN
# stage3:填充 resource
pred3_np = maskgit_sample(mg3, inp3, z, struct, GENERATE_STEP, keep_fixed=keep_fixed[2])
# stage3:填充 resource(3)
pred3_np = maskgit_sample(
mg3, inp3, z, struct, GENERATE_STEP, target_tiles=[3], keep_fixed=keep_fixed[2]
)
merged123 = merged12.copy()
merged123[pred3_np != 0] = pred3_np[pred3_np != 0]
@@ -271,13 +320,17 @@ def full_generate_specific_z(
inp2 = torch.tensor(pred1_np.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE)
inp2[inp2 == 0] = MASK_TOKEN
pred2_np = maskgit_sample(mg2, inp2, z, struct, GENERATE_STEP, keep_fixed=keep_fixed[1])
pred2_np = maskgit_sample(
mg2, inp2, z, struct, GENERATE_STEP, target_tiles=[2, 4, 5], keep_fixed=keep_fixed[1]
)
merged12 = pred1_np.copy()
merged12[pred2_np != 0] = pred2_np[pred2_np != 0]
inp3 = torch.tensor(merged12.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE)
inp3[inp3 == 0] = MASK_TOKEN
pred3_np = maskgit_sample(mg3, inp3, z, struct, GENERATE_STEP, keep_fixed=keep_fixed[2])
pred3_np = maskgit_sample(
mg3, inp3, z, struct, GENERATE_STEP, target_tiles=[3], keep_fixed=keep_fixed[2]
)
merged123 = merged12.copy()
merged123[pred3_np != 0] = pred3_np[pred3_np != 0]