From 7c85b2b8cb607106bdad86a19c46f29f3d29e3b7 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Wed, 13 May 2026 21:05:44 +0800 Subject: [PATCH] =?UTF-8?q?chore:=20=E5=AE=8C=E5=96=84=E9=AA=8C=E8=AF=81?= =?UTF-8?q?=E6=B5=81=E7=A8=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/train_seperated.py | 88 ++++++++++++++++++++++++++++------------ 1 file changed, 61 insertions(+), 27 deletions(-) diff --git a/ginka/train_seperated.py b/ginka/train_seperated.py index 425e5f6..438d3a5 100644 --- a/ginka/train_seperated.py +++ b/ginka/train_seperated.py @@ -158,7 +158,7 @@ def build_model(device: torch.device): scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=MIN_LR) # 共用 VectorQuantizer:不参与梯度更新,仅在前向时做码本查表 - quantizer = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z) + quantizer = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z).to(device) return vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler @@ -181,8 +181,10 @@ def random_struct(device: torch.device) -> torch.Tensor: def maskgit_sample( model: torch.nn.Module, inp: torch.Tensor, z: torch.Tensor, - struct: torch.Tensor, steps: int + struct: torch.Tensor, steps: int, keep_fixed: bool = True ) -> np.ndarray: + # keep_fixed=True:锁定输入中已有的非掩码位,使上一阶段结构保持不变 + # keep_fixed=False:所有位置均可被模型自由重估,适合探索更多样的生成结果 current = inp.clone() # 迭代去掩码:每步根据置信度分数重新决定掩码位置 @@ -199,10 +201,11 @@ def maskgit_sample( ratio = math.cos(((step + 1) / steps) * math.pi / 2) num_to_mask = math.floor(ratio * MAP_SIZE) - # 输入中已有的非掩码位(来自上一阶段)保持不变 - fixed_mask = (current[0] != MASK_TOKEN) - sampled[0, fixed_mask] = current[0, fixed_mask] - confidences[0, fixed_mask] = 1.0 + if keep_fixed: + # 输入中已有的非掩码位(来自上一阶段)保持不变 + fixed_mask = (current[0] != MASK_TOKEN) + sampled[0, fixed_mask] = current[0, fixed_mask] + confidences[0, fixed_mask] = 1.0 if num_to_mask > 0: # 将置信度最低的位重新掩码,留待下一步重新预测 @@ -226,7 +229,8 @@ def full_generate_random_z( input: torch.Tensor, struct: torch.Tensor, models: list[torch.nn.Module], - device: torch.device + device: torch.device, + keep_fixed: tuple[bool, bool, bool] = (True, True, True) ) -> tuple: vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler = models @@ -234,19 +238,19 @@ def full_generate_random_z( z = quantizer.sample(1, VQ_L, device) # stage1:生成 floor/wall 骨架 - pred1_np = maskgit_sample(mg1, input.clone(), z, struct, GENERATE_STEP) + pred1_np = maskgit_sample(mg1, input.clone(), z, struct, GENERATE_STEP, keep_fixed=keep_fixed[0]) inp2 = torch.tensor(pred1_np.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE) inp2[inp2 == 0] = MASK_TOKEN # 空地位交由 stage2 填充 # stage2:在骨架上生成 door/monster/entrance,非零结果覆盖合并 - pred2_np = maskgit_sample(mg2, inp2, z, struct, GENERATE_STEP) + pred2_np = maskgit_sample(mg2, inp2, z, struct, GENERATE_STEP, keep_fixed=keep_fixed[1]) merged12 = pred1_np.copy() merged12[pred2_np != 0] = pred2_np[pred2_np != 0] inp3 = torch.tensor(merged12.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE) inp3[inp3 == 0] = MASK_TOKEN # stage3:填充 resource - pred3_np = maskgit_sample(mg3, inp3, z, struct, GENERATE_STEP) + pred3_np = maskgit_sample(mg3, inp3, z, struct, GENERATE_STEP, keep_fixed=keep_fixed[2]) merged123 = merged12.copy() merged123[pred3_np != 0] = pred3_np[pred3_np != 0] @@ -257,28 +261,43 @@ def full_generate_specific_z( z: torch.Tensor, struct: torch.Tensor, models: list[torch.nn.Module], - device: torch.device + device: torch.device, + keep_fixed: tuple[bool, bool, bool] = (True, True, True) ) -> tuple: vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler = models with torch.no_grad(): # 与 full_generate_random_z 相同的三阶段级联,但使用给定的 z - pred1_np = maskgit_sample(mg1, input.clone(), z, struct, GENERATE_STEP) + pred1_np = maskgit_sample(mg1, input.clone(), z, struct, GENERATE_STEP, keep_fixed=keep_fixed[0]) inp2 = torch.tensor(pred1_np.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE) inp2[inp2 == 0] = MASK_TOKEN - pred2_np = maskgit_sample(mg2, inp2, z, struct, GENERATE_STEP) + pred2_np = maskgit_sample(mg2, inp2, z, struct, GENERATE_STEP, keep_fixed=keep_fixed[1]) merged12 = pred1_np.copy() merged12[pred2_np != 0] = pred2_np[pred2_np != 0] inp3 = torch.tensor(merged12.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE) inp3[inp3 == 0] = MASK_TOKEN - pred3_np = maskgit_sample(mg3, inp3, z, struct, GENERATE_STEP) + pred3_np = maskgit_sample(mg3, inp3, z, struct, GENERATE_STEP, keep_fixed=keep_fixed[2]) merged123 = merged12.copy() merged123[pred3_np != 0] = pred3_np[pred3_np != 0] return pred1_np, merged12, merged123 +def annotate(img: np.ndarray, text: str) -> np.ndarray: + # 在图片左上角叠加文字标注(黑色描边 + 白色填充,确保任意背景下可读) + img = img.copy() + cv2.putText(img, text, (2, 14), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 0), 2) + cv2.putText(img, text, (2, 14), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1) + return img + +def rand_keep() -> tuple[bool, bool, bool]: + b = random.choice([True, False]) + return (b, b, b) + +def keep_label(kf: tuple[bool, bool, bool]) -> str: + return 'fix' if kf[0] else 'free' + # 验证可视化 part1:3×3 网格;行1=编码器输入,行2=掩码输入,行3=三阶段预测(合并) def visualize_part1(batch, logits1, logits2, logits3, tile_dict): SEP = 3 @@ -293,10 +312,6 @@ def visualize_part1(batch, logits1, logits2, logits3, tile_dict): pred2 = torch.argmax(logits2[0], dim=-1).cpu().numpy().reshape(MAP_H, MAP_W) pred3 = torch.argmax(logits3[0], dim=-1).cpu().numpy().reshape(MAP_H, MAP_W) - pred3_merged = pred1.copy() - pred3_merged[pred2 != 0] = pred2[pred2 != 0] - pred3_merged[pred3 != 0] = pred3[pred3 != 0] - enc1_np = batch["encoder_stage1"][0].numpy().reshape(MAP_H, MAP_W) enc2_np = batch["encoder_stage2"][0].numpy().reshape(MAP_H, MAP_W) enc3_np = batch["encoder_stage3"][0].numpy().reshape(MAP_H, MAP_W) @@ -304,10 +319,18 @@ def visualize_part1(batch, logits1, logits2, logits3, tile_dict): inp2_np = batch["input_stage2"][0].numpy().reshape(MAP_H, MAP_W) inp3_np = batch["input_stage3"][0].numpy().reshape(MAP_H, MAP_W) + # 将各阶段掩码输入中的 MASK 位用模型预测值填充,保留非掩码位原值 + result1 = inp1_np.copy() + result1[inp1_np == MASK_TOKEN] = pred1[inp1_np == MASK_TOKEN] + result2 = inp2_np.copy() + result2[inp2_np == MASK_TOKEN] = pred2[inp2_np == MASK_TOKEN] + result3 = inp3_np.copy() + result3[inp3_np == MASK_TOKEN] = pred3[inp3_np == MASK_TOKEN] + rows = [ [to_img(enc1_np), to_img(enc2_np), to_img(enc3_np)], [to_img(inp1_np), to_img(inp2_np), to_img(inp3_np)], - [to_img(pred1), to_img(pred2), to_img(pred3_merged)], + [to_img(result1), to_img(result2), to_img(result3)], ] grid = np.ones((3 * img_h + 4 * SEP, 3 * img_w + 4 * SEP, 3), dtype=np.uint8) * 255 for r, row in enumerate(rows): @@ -329,9 +352,14 @@ def visualize_part2(batch, z_q, models, device, tile_dict): inp1_t = batch["input_stage1"][0:1].to(device).reshape(1, MAP_SIZE) struct_t = batch["struct_inject"][0:1].to(device) + kf = rand_keep() auto_pred1_np, auto_merged12, auto_merged123 = full_generate_specific_z( - inp1_t, z_q[0:1], struct_t, models, device + inp1_t, z_q[0:1], struct_t, models, device, keep_fixed=kf ) + kf_label = 'fix' if kf[0] else 'free' + label1 = f"s1:{kf_label}" + label2 = f"s2:{kf_label}" + label3 = f"s3:{kf_label}" enc1_np = batch["encoder_stage1"][0].numpy().reshape(MAP_H, MAP_W) enc2_np = batch["encoder_stage2"][0].numpy().reshape(MAP_H, MAP_W) @@ -340,7 +368,7 @@ def visualize_part2(batch, z_q, models, device, tile_dict): rows = [ [to_img(enc1_np), to_img(enc2_np), to_img(enc3_np)], - [to_img(inp1_np), to_img(auto_pred1_np), to_img(auto_merged12), to_img(auto_merged123)], + [to_img(inp1_np), annotate(to_img(auto_pred1_np), label1), annotate(to_img(auto_merged12), label2), annotate(to_img(auto_merged123), label3)], ] grid = np.ones((2 * img_h + 3 * SEP, 4 * img_w + 5 * SEP, 3), dtype=np.uint8) * 255 for r, row in enumerate(rows): @@ -366,13 +394,15 @@ def visualize_part3(batch, models, device, tile_dict): row1 = [to_img(inp1_np)] for _ in range(2): - _, _, merged123 = full_generate_random_z(inp1_t, struct_ref, models, device) - row1.append(to_img(merged123)) + kf = rand_keep() + _, _, merged123 = full_generate_random_z(inp1_t, struct_ref, models, device, keep_fixed=kf) + row1.append(annotate(to_img(merged123), keep_label(kf))) row2 = [] for _ in range(3): - _, _, merged123 = full_generate_random_z(inp1_t, random_struct(device), models, device) - row2.append(to_img(merged123)) + kf = rand_keep() + _, _, merged123 = full_generate_random_z(inp1_t, random_struct(device), models, device, keep_fixed=kf) + row2.append(annotate(to_img(merged123), keep_label(kf))) rows = [row1, row2] grid = np.ones((2 * img_h + 3 * SEP, 3 * img_w + 4 * SEP, 3), dtype=np.uint8) * 255 @@ -401,8 +431,9 @@ def visualize_part4(models, device, tile_dict): results = [] for _ in range(5): - _, _, merged123 = full_generate_random_z(seed, random_struct(device), models, device) - results.append(to_img(merged123)) + kf = rand_keep() + _, _, merged123 = full_generate_random_z(seed, random_struct(device), models, device, keep_fixed=kf) + results.append(annotate(to_img(merged123), keep_label(kf))) row1 = [to_img(seed_np)] + results[:2] row2 = results[2:] @@ -664,3 +695,6 @@ def train(device: torch.device): "scheduler": scheduler.state_dict(), }, final_path) tqdm.write(f"Training complete. Final model saved: {final_path}") + +if __name__ == "__main__": + train(device)