mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
feat: 图片输出添加标签说明
This commit is contained in:
parent
638c107884
commit
9bddb05625
@ -49,7 +49,7 @@ VQ_K = 2 # codebook 大小
|
|||||||
VQ_D_Z = 64 # codebook 嵌入维度
|
VQ_D_Z = 64 # codebook 嵌入维度
|
||||||
VQ_D_MODEL= 128
|
VQ_D_MODEL= 128
|
||||||
VQ_NHEAD = 4
|
VQ_NHEAD = 4
|
||||||
VQ_LAYERS = 2
|
VQ_LAYERS = 3
|
||||||
VQ_DIM_FF = 256
|
VQ_DIM_FF = 256
|
||||||
VQ_BETA = 0.25 # commit loss 权重
|
VQ_BETA = 0.25 # commit loss 权重
|
||||||
VQ_GAMMA = 0.1 # entropy loss 权重
|
VQ_GAMMA = 0.1 # entropy loss 权重
|
||||||
@ -224,6 +224,56 @@ def label_image(img: np.ndarray, text: str, font_scale: float = 0.45) -> np.ndar
|
|||||||
return np.concatenate([bar, img], axis=0)
|
return np.concatenate([bar, img], axis=0)
|
||||||
|
|
||||||
|
|
||||||
|
def struct_cond_to_text(sc: torch.Tensor) -> str:
    """Decode a ``struct_cond`` tensor of four integers into a readable label.

    ``sc`` holds, in order, ``[cond_sym, cond_room, cond_branch, cond_outer]``:

    * ``cond_sym``    -- ``sym_h*4 + sym_v*2 + sym_c``; 0-6 are flag
      combinations, 7 is the "null" marker.
    * ``cond_room``   -- room-count level 0-2, 3 is null.
    * ``cond_branch`` -- branch level 0-2, 3 is null.
    * ``cond_outer``  -- outer wall 0-1, 2 is null.

    A null field is rendered as ``-``. Any value outside the documented range
    is rendered as its raw number (for every field, including symmetry) so
    that unexpected inputs stay visible in the label instead of being
    mislabelled.

    Args:
        sc: Tensor whose ``tolist()`` yields exactly four integer-like values.

    Returns:
        A string such as ``"sym:HV room:mid br:hi wall:Y"``.

    Raises:
        ValueError: If ``sc.tolist()`` does not yield exactly four values
            (raised by the tuple unpacking).
    """
    sym_val, room_val, branch_val, outer_val = (int(x) for x in sc.tolist())

    # Symmetry: 7 is reserved as the null marker, 0-6 are bit flags.
    if sym_val == 7:
        sym_str = "sym:-"
    elif 0 <= sym_val < 7:
        flags = "".join(
            letter
            for bit, letter in ((4, "H"), (2, "V"), (1, "C"))
            if sym_val & bit
        )
        sym_str = "sym:" + (flags if flags else "none")
    else:
        # Out-of-range value: show it verbatim, like the other fields do,
        # rather than bit-testing a number that is not a valid flag set.
        sym_str = f"sym:{sym_val}"

    # Room-count level.
    room_map = {0: "room:lo", 1: "room:mid", 2: "room:hi", 3: "room:-"}
    room_str = room_map.get(room_val, f"room:{room_val}")

    # Branch level.
    branch_map = {0: "br:lo", 1: "br:mid", 2: "br:hi", 3: "br:-"}
    branch_str = branch_map.get(branch_val, f"br:{branch_val}")

    # Outer wall.
    outer_map = {0: "wall:N", 1: "wall:Y", 2: "wall:-"}
    outer_str = outer_map.get(outer_val, f"wall:{outer_val}")

    return f"{sym_str} {room_str} {branch_str} {outer_str}"
|
||||||
|
|
||||||
|
|
||||||
|
def annotate_struct(img: np.ndarray, sc: torch.Tensor) -> np.ndarray:
    """Append a one-line structure-condition caption strip below an image.

    The caption text comes from ``struct_cond_to_text(sc)``. It is drawn on a
    14-pixel-high, 3-channel ``uint8`` strip as wide as ``img`` and filled
    with ``(60, 30, 10)``, using the text colour ``(180, 220, 255)``.
    The strip is stacked under ``img`` along axis 0.

    NOTE(review): the colours read as dark blue with light text if the image
    is in OpenCV's BGR channel order -- confirm against the caller.

    Args:
        img: Image array; ``img.shape[1]`` is used as the strip width.
            Presumably an ``(H, W, 3)`` array, since it is concatenated with
            a 3-channel strip -- verify against callers.
        sc: Structure-condition tensor passed to ``struct_cond_to_text``.

    Returns:
        A new array with the caption strip appended at the bottom.
    """
    strip_height = 14
    caption = struct_cond_to_text(sc)

    strip = np.full(
        (strip_height, img.shape[1], 3),
        (60, 30, 10),
        dtype=np.uint8,
    )
    # Text origin: 2 px from the left edge, 3 px above the strip's bottom.
    origin = (2, strip_height - 3)
    cv2.putText(
        strip,
        caption,
        origin,
        cv2.FONT_HERSHEY_SIMPLEX,
        0.38,
        (180, 220, 255),
        1,
        cv2.LINE_AA,
    )

    return np.concatenate([img, strip], axis=0)
|
||||||
|
|
||||||
|
|
||||||
def make_random_wall_seed(ratio_min: float = 0.02, ratio_max: float = 0.08) -> torch.Tensor:
|
def make_random_wall_seed(ratio_min: float = 0.02, ratio_max: float = 0.08) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
在全 MASK 地图上随机放置少量墙壁作为推理种子,用于完全随机生成场景。
|
在全 MASK 地图上随机放置少量墙壁作为推理种子,用于完全随机生成场景。
|
||||||
@ -331,6 +381,13 @@ def validate(
|
|||||||
gen_real = maskgit_generate(model_mg, z_q, init_map=cond, struct_cond=sc)
|
gen_real = maskgit_generate(model_mg, z_q, init_map=cond, struct_cond=sc)
|
||||||
gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")
|
gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")
|
||||||
|
|
||||||
|
# 对使用了真实 struct_cond 的图片追加标签注释
|
||||||
|
sc0 = sc[0]
|
||||||
|
real_img = annotate_struct(real_img, sc0)
|
||||||
|
cond_img = annotate_struct(cond_img, sc0)
|
||||||
|
pred_img = annotate_struct(pred_img, sc0)
|
||||||
|
gen_r_img = annotate_struct(gen_r_img, sc0)
|
||||||
|
|
||||||
row = [real_img, cond_img, pred_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
|
row = [real_img, cond_img, pred_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
|
||||||
cv2.imwrite(f"{epoch_dir}/scene1_completion.png", grid_images(row))
|
cv2.imwrite(f"{epoch_dir}/scene1_completion.png", grid_images(row))
|
||||||
|
|
||||||
@ -344,6 +401,11 @@ def validate(
|
|||||||
gen_real = maskgit_generate(model_mg, z_q, init_map=cond, struct_cond=sc)
|
gen_real = maskgit_generate(model_mg, z_q, init_map=cond, struct_cond=sc)
|
||||||
gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")
|
gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")
|
||||||
|
|
||||||
|
sc0 = sc[0]
|
||||||
|
real_img = annotate_struct(real_img, sc0)
|
||||||
|
cond_img = annotate_struct(cond_img, sc0)
|
||||||
|
gen_r_img = annotate_struct(gen_r_img, sc0)
|
||||||
|
|
||||||
row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
|
row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
|
||||||
cv2.imwrite(f"{epoch_dir}/scene2_wall.png", grid_images(row))
|
cv2.imwrite(f"{epoch_dir}/scene2_wall.png", grid_images(row))
|
||||||
|
|
||||||
@ -357,6 +419,11 @@ def validate(
|
|||||||
gen_real = maskgit_generate(model_mg, z_q, init_map=cond, struct_cond=sc)
|
gen_real = maskgit_generate(model_mg, z_q, init_map=cond, struct_cond=sc)
|
||||||
gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")
|
gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")
|
||||||
|
|
||||||
|
sc0 = sc[0]
|
||||||
|
real_img = annotate_struct(real_img, sc0)
|
||||||
|
cond_img = annotate_struct(cond_img, sc0)
|
||||||
|
gen_r_img = annotate_struct(gen_r_img, sc0)
|
||||||
|
|
||||||
row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
|
row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
|
||||||
cv2.imwrite(f"{epoch_dir}/scene3_sparse.png", grid_images(row))
|
cv2.imwrite(f"{epoch_dir}/scene3_sparse.png", grid_images(row))
|
||||||
|
|
||||||
@ -370,6 +437,11 @@ def validate(
|
|||||||
gen_real = maskgit_generate(model_mg, z_q, init_map=cond, struct_cond=sc)
|
gen_real = maskgit_generate(model_mg, z_q, init_map=cond, struct_cond=sc)
|
||||||
gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")
|
gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")
|
||||||
|
|
||||||
|
sc0 = sc[0]
|
||||||
|
real_img = annotate_struct(real_img, sc0)
|
||||||
|
cond_img = annotate_struct(cond_img, sc0)
|
||||||
|
gen_r_img = annotate_struct(gen_r_img, sc0)
|
||||||
|
|
||||||
row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
|
row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
|
||||||
cv2.imwrite(f"{epoch_dir}/scene4_entrance.png", grid_images(row))
|
cv2.imwrite(f"{epoch_dir}/scene4_entrance.png", grid_images(row))
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user