diff --git a/ginka/train_vq.py b/ginka/train_vq.py index 364858a..f67e70f 100644 --- a/ginka/train_vq.py +++ b/ginka/train_vq.py @@ -42,6 +42,7 @@ GENERATE_STEP = 18 # 推理时 MaskGIT 迭代步数 MAP_SIZE = 13 * 13 MAP_H = MAP_W = 13 LABEL_SMOOTHING = 0.0 +WALL_MASK_RATIO = 0.8 # VQ-VAE 超参 VQ_L = 16 # summary token 数量(即 z 的序列长度) @@ -288,6 +289,21 @@ def make_random_wall_seed(ratio_min: float = 0.02, ratio_max: float = 0.08) -> t seed[0, idx] = 1 # wall return seed + +def make_random_struct_cond() -> torch.Tensor: + """ + 生成一个随机结构条件,所有标签均取合法非-null 值。 + + Returns: + [1, 4] LongTensor,顺序 [cond_sym, cond_room, cond_branch, cond_outer] + """ + from .maskGIT.model import SYM_VOCAB, ROOM_VOCAB, BRANCH_VOCAB, OUTER_VOCAB + sym = random.randint(0, SYM_VOCAB - 2) # 0-6 + room = random.randint(0, ROOM_VOCAB - 2) # 0-2 + branch = random.randint(0, BRANCH_VOCAB - 2) # 0-2 + outer = random.randint(0, OUTER_VOCAB - 2) # 0-1 + return torch.tensor([[sym, room, branch, outer]], dtype=torch.long, device=device) + @torch.no_grad() def validate( model_vq: GinkaVQVAE, @@ -356,7 +372,7 @@ def validate( if all(v is not None for v in captured.values()): break - # ── 公共辅助:对给定条件地图随机采样 n 次 z 并迭代生成 ────────────────── + # ── 公共辅助:对给定条件地图随机采样 n 次 z 并迭代生成(无条件)────────────── def _rand_gens(cond_map, n): imgs = [] for i in range(n): @@ -365,6 +381,18 @@ def validate( imgs.append(label_image(make_map_image(gen[0], tile_dict), f"z_rand_{i + 1}")) return imgs + # ── 公共辅助:对给定条件地图随机采样 n 次 z 并迭代生成(随机结构标签)──────── + def _rand_gens_with_struct(cond_map, n): + imgs = [] + for i in range(n): + z_r = model_vq.sample(1, device) + sc_r = make_random_struct_cond() # [1, 4] 随机合法标签 + gen = maskgit_generate(model_mg, z_r, init_map=cond_map, struct_cond=sc_r) + img = label_image(make_map_image(gen[0], tile_dict), f"z_rand_{i + 1}") + img = annotate_struct(img, sc_r[0]) + imgs.append(img) + return imgs + # ── 场景1:标准掩码补全(子集 A)───────────────────────────────────────── if captured['A'] is not None: cap = captured['A'] @@ -446,11 +474,17 @@ def validate( cv2.imwrite(f"{epoch_dir}/scene4_entrance.png", grid_images(row)) # ── 场景5:完全随机生成(无数据集参照)────────────────────────────────── - # 随机稀疏墙壁种子,对应推理时"不提供任何条件,直接生成"的场景 - rand_seed = make_random_wall_seed() # [1, MAP_SIZE] - seed_img = label_image(make_map_image(rand_seed[0], tile_dict), "random seed") - row = [seed_img] + _rand_gens(rand_seed, N_Z_SAMPLES + 1) # 多采一个 z 展示多样性 - cv2.imwrite(f"{epoch_dir}/scene5_random.png", grid_images(row)) + # 5a:随机结构标签 — 验证结构导向能力 + rand_seed_a = make_random_wall_seed() + seed_img_a = label_image(make_map_image(rand_seed_a[0], tile_dict), "random seed") + row_a = [seed_img_a] + _rand_gens_with_struct(rand_seed_a, N_Z_SAMPLES + 1) + cv2.imwrite(f"{epoch_dir}/scene5a_random_cond.png", grid_images(row_a)) + + # 5b:无条件(struct_cond=None)— 验证基线生成质量 + rand_seed_b = make_random_wall_seed() + seed_img_b = label_image(make_map_image(rand_seed_b[0], tile_dict), "random seed") + row_b = [seed_img_b] + _rand_gens(rand_seed_b, N_Z_SAMPLES + 1) + cv2.imwrite(f"{epoch_dir}/scene5b_random_uncond.png", grid_images(row_b)) avg_val_loss = val_loss_total / max(val_steps, 1) return avg_val_loss @@ -492,12 +526,14 @@ def train(): dataset_train = GinkaVQDataset( args.train, subset_weights=SUBSET_WEIGHTS, + wall_mask_ratio=WALL_MASK_RATIO, ) dataset_val = GinkaVQDataset( args.validate, subset_weights=SUBSET_WEIGHTS, room_thresholds=dataset_train.room_th, branch_thresholds=dataset_train.branch_th, + wall_mask_ratio=WALL_MASK_RATIO, ) dataloader_train = DataLoader( dataset_train, batch_size=BATCH_SIZE, shuffle=True,