feat: 图片输出添加标签说明

This commit is contained in:
unanmed 2026-04-27 15:31:50 +08:00
parent 638c107884
commit 9bddb05625

View File

@ -49,7 +49,7 @@ VQ_K = 2 # codebook 大小
VQ_D_Z = 64 # codebook 嵌入维度 VQ_D_Z = 64 # codebook 嵌入维度
VQ_D_MODEL= 128 VQ_D_MODEL= 128
VQ_NHEAD = 4 VQ_NHEAD = 4
VQ_LAYERS = 2 VQ_LAYERS = 3
VQ_DIM_FF = 256 VQ_DIM_FF = 256
VQ_BETA = 0.25 # commit loss 权重 VQ_BETA = 0.25 # commit loss 权重
VQ_GAMMA = 0.1 # entropy loss 权重 VQ_GAMMA = 0.1 # entropy loss 权重
@ -224,6 +224,56 @@ def label_image(img: np.ndarray, text: str, font_scale: float = 0.45) -> np.ndar
return np.concatenate([bar, img], axis=0) return np.concatenate([bar, img], axis=0)
def struct_cond_to_text(sc: torch.Tensor) -> str:
"""
struct_cond [4] LongTensor 解码为可读字符串
sc 顺序[cond_sym, cond_room, cond_branch, cond_outer]
cond_sym : sym_h*4 + sym_v*2 + sym_c取值 0-67=null
cond_room : roomCountLevel 0-23=null
cond_branch: branchLevel 0-23=null
cond_outer : outerWall 0-12=null
"""
sym_val, room_val, branch_val, outer_val = (int(x) for x in sc.tolist())
# 对称性
if sym_val == 7:
sym_str = "sym:-"
else:
flags = []
if sym_val & 4: flags.append("H")
if sym_val & 2: flags.append("V")
if sym_val & 1: flags.append("C")
sym_str = "sym:" + ("".join(flags) if flags else "none")
# 房间数量等级
room_map = {0: "room:lo", 1: "room:mid", 2: "room:hi", 3: "room:-"}
room_str = room_map.get(room_val, f"room:{room_val}")
# 分支等级
branch_map = {0: "br:lo", 1: "br:mid", 2: "br:hi", 3: "br:-"}
branch_str = branch_map.get(branch_val, f"br:{branch_val}")
# 外墙
outer_map = {0: "wall:N", 1: "wall:Y", 2: "wall:-"}
outer_str = outer_map.get(outer_val, f"wall:{outer_val}")
return f"{sym_str} {room_str} {branch_str} {outer_str}"
def annotate_struct(img: np.ndarray, sc: torch.Tensor) -> np.ndarray:
"""在图片底部追加一行结构标签注释(深蓝底白字)。"""
text = struct_cond_to_text(sc)
bar_h = 14
bar = np.full((bar_h, img.shape[1], 3), (60, 30, 10), dtype=np.uint8)
cv2.putText(
bar, text, (2, bar_h - 3),
cv2.FONT_HERSHEY_SIMPLEX, 0.38,
(180, 220, 255), 1, cv2.LINE_AA
)
return np.concatenate([img, bar], axis=0)
def make_random_wall_seed(ratio_min: float = 0.02, ratio_max: float = 0.08) -> torch.Tensor: def make_random_wall_seed(ratio_min: float = 0.02, ratio_max: float = 0.08) -> torch.Tensor:
""" """
在全 MASK 地图上随机放置少量墙壁作为推理种子用于完全随机生成场景 在全 MASK 地图上随机放置少量墙壁作为推理种子用于完全随机生成场景
@ -331,6 +381,13 @@ def validate(
gen_real = maskgit_generate(model_mg, z_q, init_map=cond, struct_cond=sc) gen_real = maskgit_generate(model_mg, z_q, init_map=cond, struct_cond=sc)
gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen") gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")
# 对使用了真实 struct_cond 的图片追加标签注释
sc0 = sc[0]
real_img = annotate_struct(real_img, sc0)
cond_img = annotate_struct(cond_img, sc0)
pred_img = annotate_struct(pred_img, sc0)
gen_r_img = annotate_struct(gen_r_img, sc0)
row = [real_img, cond_img, pred_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES) row = [real_img, cond_img, pred_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
cv2.imwrite(f"{epoch_dir}/scene1_completion.png", grid_images(row)) cv2.imwrite(f"{epoch_dir}/scene1_completion.png", grid_images(row))
@ -344,6 +401,11 @@ def validate(
gen_real = maskgit_generate(model_mg, z_q, init_map=cond, struct_cond=sc) gen_real = maskgit_generate(model_mg, z_q, init_map=cond, struct_cond=sc)
gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen") gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")
sc0 = sc[0]
real_img = annotate_struct(real_img, sc0)
cond_img = annotate_struct(cond_img, sc0)
gen_r_img = annotate_struct(gen_r_img, sc0)
row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES) row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
cv2.imwrite(f"{epoch_dir}/scene2_wall.png", grid_images(row)) cv2.imwrite(f"{epoch_dir}/scene2_wall.png", grid_images(row))
@ -357,6 +419,11 @@ def validate(
gen_real = maskgit_generate(model_mg, z_q, init_map=cond, struct_cond=sc) gen_real = maskgit_generate(model_mg, z_q, init_map=cond, struct_cond=sc)
gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen") gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")
sc0 = sc[0]
real_img = annotate_struct(real_img, sc0)
cond_img = annotate_struct(cond_img, sc0)
gen_r_img = annotate_struct(gen_r_img, sc0)
row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES) row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
cv2.imwrite(f"{epoch_dir}/scene3_sparse.png", grid_images(row)) cv2.imwrite(f"{epoch_dir}/scene3_sparse.png", grid_images(row))
@ -370,6 +437,11 @@ def validate(
gen_real = maskgit_generate(model_mg, z_q, init_map=cond, struct_cond=sc) gen_real = maskgit_generate(model_mg, z_q, init_map=cond, struct_cond=sc)
gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen") gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")
sc0 = sc[0]
real_img = annotate_struct(real_img, sc0)
cond_img = annotate_struct(cond_img, sc0)
gen_r_img = annotate_struct(gen_r_img, sc0)
row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES) row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
cv2.imwrite(f"{epoch_dir}/scene4_entrance.png", grid_images(row)) cv2.imwrite(f"{epoch_dir}/scene4_entrance.png", grid_images(row))