diff --git a/ginka/train_seperated.py b/ginka/train_seperated.py
index 8e164ec..eb6703e 100644
--- a/ginka/train_seperated.py
+++ b/ginka/train_seperated.py
@@ -180,11 +180,21 @@ def random_struct(device: torch.device) -> torch.Tensor:
 
 def maskgit_sample(
     model: torch.nn.Module, inp: torch.Tensor, z: torch.Tensor,
-    struct: torch.Tensor, steps: int, keep_fixed: bool = True
+    struct: torch.Tensor, steps: int, target_tiles: list[int] | None = None,
+    keep_fixed: bool = True
 ) -> np.ndarray:
-    # keep_fixed=True:锁定输入中已有的非掩码位,使上一阶段结构保持不变
-    # keep_fixed=False:所有位置均可被模型自由重估,适合探索更多样的生成结果
+    # target_tiles: 本阶段负责生成的图块 ID 列表;None 表示接受所有类别(stage1)
+    # keep_fixed=True:锁定输入中已有的非掩码/非空地位,使上一阶段结构保持不变
+    # keep_fixed=False:结构位保留,但每步结束后空地重新标为 MASK(探索模式)
+    #
+    # 有 target_tiles 时的核心逻辑:
+    # 每步只从预测中选出置信度最高的若干 target_tile 候选揭开,
+    # 其余已有结构(墙/门等非空地非掩码)原样保留,
+    # 空地与掩码保持为 MASK,等待后续步骤继续填充。
     current = inp.clone()
+    has_target = target_tiles is not None
+    if has_target:
+        target_tensor = torch.tensor(target_tiles, dtype=torch.long, device=inp.device)
 
     # 迭代去掩码:每步根据置信度分数重新决定掩码位置
     for step in range(steps):
@@ -200,27 +210,62 @@ def maskgit_sample(
         ratio = math.cos(((step + 1) / steps) * math.pi / 2)
         num_to_mask = math.floor(ratio * MAP_SIZE)
 
-        if keep_fixed:
-            # 输入中已有的非掩码位(来自上一阶段)保持不变
-            fixed_mask = (current[0] != MASK_TOKEN)
-            sampled[0, fixed_mask] = current[0, fixed_mask]
-            confidences[0, fixed_mask] = 1.0
-
-        if num_to_mask > 0:
-            # 将置信度最低的位重新掩码,留待下一步重新预测
-            _, mask_indices = torch.topk(confidences[0], k=num_to_mask, largest=False)
-            sampled[0].scatter_(0, mask_indices, MASK_TOKEN)
-
-        current = sampled
+        if not has_target:
+            # stage1:无 target 约束,仅锁定 fixed 位(若 keep_fixed)
+            if keep_fixed:
+                fixed_mask = (current[0] != MASK_TOKEN)
+                sampled[0, fixed_mask] = current[0, fixed_mask]
+                confidences[0, fixed_mask] = 1.0
+            if num_to_mask > 0:
+                _, mask_indices = torch.topk(confidences[0], k=num_to_mask, largest=False)
+                sampled[0].scatter_(0, mask_indices, MASK_TOKEN)
+            current = sampled
+        else:
+            # 有 target_tiles:基于当前 current 构建下一状态
+            # 结构位:current 中非空地、非掩码的位置(来自上一阶段,始终保留)
+            struct_mask = (current[0] != MASK_TOKEN) & (current[0] != 0)
+            # 候选位:sampled 为目标图块且不覆盖结构位
+            candidate_mask = torch.isin(sampled[0], target_tensor) & ~struct_mask
+            # 对候选位按置信度排序,选出置信度最高的若干位揭开
+            cand_count = candidate_mask.sum().item()
+            reveal_count = max(0, int(cand_count) - num_to_mask)
+            next_state = current[0].clone()
+            if reveal_count > 0 and cand_count > 0:
+                cand_indices = candidate_mask.nonzero(as_tuple=True)[0]
+                cand_conf = confidences[0][cand_indices]
+                top_k = min(reveal_count, cand_conf.size(0))
+                _, top_idx = torch.topk(cand_conf, k=top_k, largest=True)
+                reveal_indices = cand_indices[top_idx]
+                next_state[reveal_indices] = sampled[0][reveal_indices]
+                # 后处理:进度未超 75% 时,随机将新揭开位的 20%-40% 再次掩码,
+                # 抑制目标图块过密生成;后期不再压制,确保最终能全部揭开
+                if step / steps <= 0.75:
+                    suppress_ratio = random.uniform(0.2, 0.4)
+                    suppress_k = max(1, int(reveal_indices.size(0) * suppress_ratio))
+                    suppress_perm = torch.randperm(
+                        reveal_indices.size(0), device=inp.device
+                    )[:suppress_k]
+                    next_state[reveal_indices[suppress_perm]] = MASK_TOKEN
+            # 结构位原样保留,其余未揭开的置为 MASK
+            non_struct_non_revealed = (next_state == current[0]) & ~struct_mask
+            next_state[non_struct_non_revealed & (next_state != MASK_TOKEN)] = MASK_TOKEN
+            # free 模式下,空地也重新标为 MASK(允许下一步继续填充)
+            if not keep_fixed:
+                next_state[next_state == 0] = MASK_TOKEN
+            current = next_state.unsqueeze(0)
 
         if (current[0] == MASK_TOKEN).sum() == 0:
             break
 
-    # 兜底:若仍有残余掩码位(理论上不应发生),用 argmax 确定性填充
+    # 兜底:若仍有残余掩码位,按模式填充
     still_masked = (current[0] == MASK_TOKEN)
     if still_masked.any():
-        logits = model(current, z, struct)
-        current[0, still_masked] = torch.argmax(logits[0, still_masked], dim=-1)
+        if has_target:
+            # 目标模式下,未被填充的位置视为空地(不属于本阶段负责的图块)
+            current[0, still_masked] = 0
+        else:
+            logits = model(current, z, struct)
+            current[0, still_masked] = torch.argmax(logits[0, still_masked], dim=-1)
 
     return current[0].cpu().numpy().reshape(MAP_H, MAP_W)
 
@@ -241,15 +286,19 @@ def full_generate_random_z(
     inp2 = torch.tensor(pred1_np.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE)
     inp2[inp2 == 0] = MASK_TOKEN  # 空地位交由 stage2 填充
 
-    # stage2:在骨架上生成 door/monster/entrance,非零结果覆盖合并
-    pred2_np = maskgit_sample(mg2, inp2, z, struct, GENERATE_STEP, keep_fixed=keep_fixed[1])
+    # stage2:在骨架上生成 door(2)/monster(4)/entrance(5),非零结果覆盖合并
+    pred2_np = maskgit_sample(
+        mg2, inp2, z, struct, GENERATE_STEP, target_tiles=[2, 4, 5], keep_fixed=keep_fixed[1]
+    )
     merged12 = pred1_np.copy()
     merged12[pred2_np != 0] = pred2_np[pred2_np != 0]
 
     inp3 = torch.tensor(merged12.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE)
     inp3[inp3 == 0] = MASK_TOKEN
-    # stage3:填充 resource
-    pred3_np = maskgit_sample(mg3, inp3, z, struct, GENERATE_STEP, keep_fixed=keep_fixed[2])
+    # stage3:填充 resource(3)
+    pred3_np = maskgit_sample(
+        mg3, inp3, z, struct, GENERATE_STEP, target_tiles=[3], keep_fixed=keep_fixed[2]
+    )
     merged123 = merged12.copy()
     merged123[pred3_np != 0] = pred3_np[pred3_np != 0]
 
@@ -271,13 +320,17 @@ def full_generate_specific_z(
     inp2 = torch.tensor(pred1_np.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE)
     inp2[inp2 == 0] = MASK_TOKEN
 
-    pred2_np = maskgit_sample(mg2, inp2, z, struct, GENERATE_STEP, keep_fixed=keep_fixed[1])
+    pred2_np = maskgit_sample(
+        mg2, inp2, z, struct, GENERATE_STEP, target_tiles=[2, 4, 5], keep_fixed=keep_fixed[1]
+    )
     merged12 = pred1_np.copy()
     merged12[pred2_np != 0] = pred2_np[pred2_np != 0]
 
     inp3 = torch.tensor(merged12.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE)
     inp3[inp3 == 0] = MASK_TOKEN
-    pred3_np = maskgit_sample(mg3, inp3, z, struct, GENERATE_STEP, keep_fixed=keep_fixed[2])
+    pred3_np = maskgit_sample(
+        mg3, inp3, z, struct, GENERATE_STEP, target_tiles=[3], keep_fixed=keep_fixed[2]
+    )
     merged123 = merged12.copy()
     merged123[pred3_np != 0] = pred3_np[pred3_np != 0]
 