From 9bddb05625e05ed65ed5818e891a1fa0c3c98a28 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Mon, 27 Apr 2026 15:31:50 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=9B=BE=E7=89=87=E8=BE=93=E5=87=BA?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E6=A0=87=E7=AD=BE=E8=AF=B4=E6=98=8E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/train_vq.py | 74 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 73 insertions(+), 1 deletion(-) diff --git a/ginka/train_vq.py b/ginka/train_vq.py index 269e0c5..364858a 100644 --- a/ginka/train_vq.py +++ b/ginka/train_vq.py @@ -49,7 +49,7 @@ VQ_K = 2 # codebook 大小 VQ_D_Z = 64 # codebook 嵌入维度 VQ_D_MODEL= 128 VQ_NHEAD = 4 -VQ_LAYERS = 2 +VQ_LAYERS = 3 VQ_DIM_FF = 256 VQ_BETA = 0.25 # commit loss 权重 VQ_GAMMA = 0.1 # entropy loss 权重 @@ -224,6 +224,56 @@ def label_image(img: np.ndarray, text: str, font_scale: float = 0.45) -> np.ndar return np.concatenate([bar, img], axis=0) +def struct_cond_to_text(sc: torch.Tensor) -> str: + """ + 将 struct_cond [4] LongTensor 解码为可读字符串。 + + sc 顺序:[cond_sym, cond_room, cond_branch, cond_outer] + cond_sym : sym_h*4 + sym_v*2 + sym_c,取值 0-6,7=null + cond_room : roomCountLevel 0-2,3=null + cond_branch: branchLevel 0-2,3=null + cond_outer : outerWall 0-1,2=null + """ + sym_val, room_val, branch_val, outer_val = (int(x) for x in sc.tolist()) + + # 对称性 + if sym_val == 7: + sym_str = "sym:-" + else: + flags = [] + if sym_val & 4: flags.append("H") + if sym_val & 2: flags.append("V") + if sym_val & 1: flags.append("C") + sym_str = "sym:" + ("".join(flags) if flags else "none") + + # 房间数量等级 + room_map = {0: "room:lo", 1: "room:mid", 2: "room:hi", 3: "room:-"} + room_str = room_map.get(room_val, f"room:{room_val}") + + # 分支等级 + branch_map = {0: "br:lo", 1: "br:mid", 2: "br:hi", 3: "br:-"} + branch_str = branch_map.get(branch_val, f"br:{branch_val}") + + # 外墙 + outer_map = {0: "wall:N", 1: "wall:Y", 2: "wall:-"} + outer_str = outer_map.get(outer_val, f"wall:{outer_val}") + + return f"{sym_str} {room_str} {branch_str} {outer_str}" + + +def annotate_struct(img: np.ndarray, sc: torch.Tensor) -> np.ndarray: + """在图片底部追加一行结构标签注释(深蓝底白字)。""" + text = struct_cond_to_text(sc) + bar_h = 14 + bar = np.full((bar_h, img.shape[1], 3), (60, 30, 10), dtype=np.uint8) + cv2.putText( + bar, text, (2, bar_h - 3), + cv2.FONT_HERSHEY_SIMPLEX, 0.38, + (180, 220, 255), 1, cv2.LINE_AA + ) + return np.concatenate([img, bar], axis=0) + + def make_random_wall_seed(ratio_min: float = 0.02, ratio_max: float = 0.08) -> torch.Tensor: """ 在全 MASK 地图上随机放置少量墙壁作为推理种子,用于完全随机生成场景。 @@ -331,6 +381,13 @@ def validate( gen_real = maskgit_generate(model_mg, z_q, init_map=cond, struct_cond=sc) gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen") + # 对使用了真实 struct_cond 的图片追加标签注释 + sc0 = sc[0] + real_img = annotate_struct(real_img, sc0) + cond_img = annotate_struct(cond_img, sc0) + pred_img = annotate_struct(pred_img, sc0) + gen_r_img = annotate_struct(gen_r_img, sc0) + row = [real_img, cond_img, pred_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES) cv2.imwrite(f"{epoch_dir}/scene1_completion.png", grid_images(row)) @@ -344,6 +401,11 @@ def validate( gen_real = maskgit_generate(model_mg, z_q, init_map=cond, struct_cond=sc) gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen") + sc0 = sc[0] + real_img = annotate_struct(real_img, sc0) + cond_img = annotate_struct(cond_img, sc0) + gen_r_img = annotate_struct(gen_r_img, sc0) + row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES) cv2.imwrite(f"{epoch_dir}/scene2_wall.png", grid_images(row)) @@ -357,6 +419,11 @@ def validate( gen_real = maskgit_generate(model_mg, z_q, init_map=cond, struct_cond=sc) gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen") + sc0 = sc[0] + real_img = annotate_struct(real_img, sc0) + cond_img = annotate_struct(cond_img, sc0) + gen_r_img = annotate_struct(gen_r_img, sc0) + row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES) cv2.imwrite(f"{epoch_dir}/scene3_sparse.png", grid_images(row)) @@ -370,6 +437,11 @@ def validate( gen_real = maskgit_generate(model_mg, z_q, init_map=cond, struct_cond=sc) gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen") + sc0 = sc[0] + real_img = annotate_struct(real_img, sc0) + cond_img = annotate_struct(cond_img, sc0) + gen_r_img = annotate_struct(gen_r_img, sc0) + row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES) cv2.imwrite(f"{epoch_dir}/scene4_entrance.png", grid_images(row))