From 21315a6cb053111a78e6f037d38b2b312157ffd8 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Mon, 27 Apr 2026 13:46:48 +0800 Subject: [PATCH] =?UTF-8?q?chore:=20=E8=B0=83=E6=95=B4=E8=B6=85=E5=8F=82?= =?UTF-8?q?=E6=95=B0=E4=B8=8E=E7=BB=86=E8=8A=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Copilot --- docs/vqvae-maskgit-design.md | 8 ++--- ginka/dataset.py | 24 ++++++++------ ginka/train_vq.py | 64 ++++++++++++++++++++++++++++-------- ginka/vqvae/model.py | 2 +- 4 files changed, 69 insertions(+), 29 deletions(-) diff --git a/docs/vqvae-maskgit-design.md b/docs/vqvae-maskgit-design.md index ede2202..a2b505e 100644 --- a/docs/vqvae-maskgit-design.md +++ b/docs/vqvae-maskgit-design.md @@ -391,7 +391,7 @@ $$\mathcal{L} = \mathcal{L}_{CE}(\text{MaskGIT}) + \beta \cdot \mathcal{L}_{comm - [x] 确定是否保留标量 cond(暂不使用) - [x] 确定训练子集划分与各场景比例 - [x] 实现 VQ-VAE 编码器模块(Transformer + VQ) -- [ ] 改造 `GinkaMaskGIT.forward()` 接受 z,替换热力图分支 -- [ ] 实现四种子集采样逻辑(`dataset.py` 新增多子集 Dataset 或采样权重) -- [ ] 实现子集 B/C/D 的输入构造函数(按规则清除 tile、保留墙壁/入口) -- [ ] 编写联合训练脚本,整合 VQ 损失与 MaskGIT 交叉熵损失 +- [x] 改造 `GinkaMaskGIT.forward()` 接受 z,替换热力图分支 +- [x] 实现四种子集采样逻辑(`dataset.py` 新增多子集 Dataset 或采样权重) +- [x] 实现子集 B/C/D 的输入构造函数(按规则清除 tile、保留墙壁/入口) +- [x] 编写联合训练脚本,整合 VQ 损失与 MaskGIT 交叉熵损失 diff --git a/ginka/dataset.py b/ginka/dataset.py index beb4b6a..f79debf 100644 --- a/ginka/dataset.py +++ b/ginka/dataset.py @@ -364,16 +364,16 @@ class GinkaVQDataset(Dataset): return flat elif subset == 'B': - # 仅保留 floor(0) 和 wall(1) + # 仅保留 wall(1),floor(0) 和其他非墙内容全部 mask flat = raw.reshape(-1).copy() - keep = (flat == self.FLOOR) | (flat == self.WALL) + keep = (flat == self.WALL) flat[~keep] = self.MASK_ID return flat elif subset == 'C': # Subset B + 随机 mask 部分 wall flat = raw.reshape(-1).copy() - keep = (flat == self.FLOOR) | (flat == self.WALL) + keep = (flat == self.WALL) flat[~keep] = self.MASK_ID wall_idx = np.where(flat == self.WALL)[0] @@ -385,14 +385,18 @@ class GinkaVQDataset(Dataset): return flat else: # D - # 仅保留 floor(0)、wall(1) 和 entrance(10) + # 仅保留 wall(1) 和 entrance(10),floor(0) 和其他非墙内容全部 mask flat = raw.reshape(-1).copy() - keep = ( - (flat == self.FLOOR) - | (flat == self.WALL) - | (flat == self.ENTRANCE) - ) + keep = (flat == self.WALL) | (flat == self.ENTRANCE) flat[~keep] = self.MASK_ID + + # 随机 mask 部分 wall(模拟真实场景,与子集 C 一致) + wall_idx = np.where(flat == self.WALL)[0] + if len(wall_idx) > 0: + ratio = random.random() * self.wall_mask_ratio + n = max(1, int(len(wall_idx) * ratio)) + chosen = np.random.choice(wall_idx, n, replace=False) + flat[chosen] = self.MASK_ID return flat def __getitem__(self, idx): @@ -429,4 +433,4 @@ if __name__ == "__main__": print(f"masked_map shape={masked.shape}, dtype={masked.dtype}") print(f"target_map shape={target.shape}, dtype={target.dtype}") print(f"被 mask 的位置数: {(masked == 15).sum().item()} / {masked.numel()}") - print(f"\n200 次采样子集分布: {subset_count}") \ No newline at end of file + print(f"\n200 次采样子集分布: {subset_count}") diff --git a/ginka/train_vq.py b/ginka/train_vq.py index 3035a9d..65eda30 100644 --- a/ginka/train_vq.py +++ b/ginka/train_vq.py @@ -44,8 +44,8 @@ MAP_H = MAP_W = 13 LABEL_SMOOTHING = 0.0 # VQ-VAE 超参 -VQ_L = 2 # summary token 数量(即 z 的序列长度) -VQ_K = 16 # codebook 大小 +VQ_L = 64 # summary token 数量(即 z 的序列长度) +VQ_K = 1 # codebook 大小 VQ_D_Z = 64 # codebook 嵌入维度 VQ_D_MODEL= 128 VQ_NHEAD = 4 @@ -180,6 +180,36 @@ def hstack_images(imgs: list, gap: int = 4, color=(255, 255, 255)) -> np.ndarray return result +def grid_images(imgs: list, gap: int = 4, bg_color=(255, 255, 255)) -> np.ndarray: + """将图片列表排成两行网格(上行 ceil(N/2),下行 floor(N/2)),方便查看。""" + n = len(imgs) + if n == 0: + return np.zeros((1, 1, 3), dtype=np.uint8) + if n == 1: + return imgs[0] + + mid = math.ceil(n / 2) + top_row = hstack_images(imgs[:mid], gap, bg_color) + bot_imgs = imgs[mid:] + + if not bot_imgs: + return top_row + + bot_row = hstack_images(bot_imgs, gap, bg_color) + + # 补齐宽度(右侧填充背景色) + tw, bw = top_row.shape[1], bot_row.shape[1] + if tw > bw: + pad = np.full((bot_row.shape[0], tw - bw, 3), bg_color, dtype=np.uint8) + bot_row = np.concatenate([bot_row, pad], axis=1) + elif bw > tw: + pad = np.full((top_row.shape[0], bw - tw, 3), bg_color, dtype=np.uint8) + top_row = np.concatenate([top_row, pad], axis=1) + + hline = np.full((gap, top_row.shape[1], 3), bg_color, dtype=np.uint8) + return np.concatenate([top_row, hline, bot_row], axis=0) + + def label_image(img: np.ndarray, text: str, font_scale: float = 0.45) -> np.ndarray: """在图片顶部加一行文字标签(就地修改并返回)。""" bar_h = 16 @@ -248,7 +278,7 @@ def validate( subsets = batch["subset"] # list of str B = raw_map.shape[0] - z_q, _, vq_loss = model_vq(raw_map) + z_q, _, vq_loss, _, _ = model_vq(raw_map) logits = model_mg(masked_map, z_q) mask = (masked_map == MASK_TOKEN) @@ -298,7 +328,7 @@ def validate( gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen") row = [real_img, cond_img, pred_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES) - cv2.imwrite(f"{epoch_dir}/scene1_completion.png", hstack_images(row)) + cv2.imwrite(f"{epoch_dir}/scene1_completion.png", grid_images(row)) # ── 场景2:墙壁辅助生成(子集 B)───────────────────────────────────────── if captured['B'] is not None: @@ -311,7 +341,7 @@ def validate( gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen") row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES) - cv2.imwrite(f"{epoch_dir}/scene2_wall.png", hstack_images(row)) + cv2.imwrite(f"{epoch_dir}/scene2_wall.png", grid_images(row)) # ── 场景3:稀疏墙壁条件生成(子集 C)──────────────────────────────────── if captured['C'] is not None: @@ -324,7 +354,7 @@ def validate( gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen") row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES) - cv2.imwrite(f"{epoch_dir}/scene3_sparse.png", hstack_images(row)) + cv2.imwrite(f"{epoch_dir}/scene3_sparse.png", grid_images(row)) # ── 场景4:墙壁+入口条件生成(子集 D)─────────────────────────────────── if captured['D'] is not None: @@ -337,14 +367,14 @@ def validate( gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen") row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES) - cv2.imwrite(f"{epoch_dir}/scene4_entrance.png", hstack_images(row)) + cv2.imwrite(f"{epoch_dir}/scene4_entrance.png", grid_images(row)) # ── 场景5:完全随机生成(无数据集参照)────────────────────────────────── # 随机稀疏墙壁种子,对应推理时"不提供任何条件,直接生成"的场景 rand_seed = make_random_wall_seed() # [1, MAP_SIZE] seed_img = label_image(make_map_image(rand_seed[0], tile_dict), "random seed") row = [seed_img] + _rand_gens(rand_seed, N_Z_SAMPLES + 1) # 多采一个 z 展示多样性 - cv2.imwrite(f"{epoch_dir}/scene5_random.png", hstack_images(row)) + cv2.imwrite(f"{epoch_dir}/scene5_random.png", grid_images(row)) avg_val_loss = val_loss_total / max(val_steps, 1) return avg_val_loss @@ -431,10 +461,12 @@ def train(): model_vq.train() model_mg.train() - loss_total = 0.0 - ce_total = 0.0 - vq_loss_total = 0.0 - subset_stats = {'A': 0, 'B': 0, 'C': 0, 'D': 0} + loss_total = 0.0 + ce_total = 0.0 + vq_loss_total = 0.0 + commit_total = 0.0 + entropy_total = 0.0 + subset_stats = {'A': 0, 'B': 0, 'C': 0, 'D': 0} for batch in tqdm(dataloader_train, leave=False, desc="Epoch Progress", disable=disable_tqdm): @@ -447,7 +479,7 @@ def train(): # ---- 前向传播 ---- # 1. VQ-VAE 编码真实地图 → z_q - z_q, _, vq_loss = model_vq(raw_map) # z_q: [B, L, d_z] + z_q, _, vq_loss, commit_loss, entropy_loss = model_vq(raw_map) # z_q: [B, L, d_z] # 2. MaskGIT 以掩码地图 + z 预测原始 tile logits = model_mg(masked_map, z_q) # [B, 169, C] @@ -471,6 +503,8 @@ def train(): loss_total += loss.detach().item() ce_total += masked_ce.detach().item() vq_loss_total += vq_loss.detach().item() + commit_total += commit_loss.detach().item() + entropy_total += entropy_loss.detach().item() scheduler.step() @@ -480,7 +514,9 @@ def train(): f"Epoch {epoch + 1:4d} | " f"Loss {loss_total/n:.5f} " f"CE {ce_total/n:.5f} " - f"VQ {vq_loss_total/n:.5f} | " + f"VQ {vq_loss_total/n:.5f} " + f"Commit {commit_total/n:.5f} " + f"Entropy {entropy_total/n:.5f} | " f"LR {scheduler.get_last_lr()[0]:.6f} | " f"Subsets {subset_stats}" ) diff --git a/ginka/vqvae/model.py b/ginka/vqvae/model.py index 79ee4c6..617a5f9 100644 --- a/ginka/vqvae/model.py +++ b/ginka/vqvae/model.py @@ -131,7 +131,7 @@ class GinkaVQVAE(nn.Module): z_q, indices, commit_loss, entropy_loss = self.vq(z_e) vq_loss = self.beta * commit_loss + self.gamma * entropy_loss - return z_q, indices, vq_loss + return z_q, indices, vq_loss, commit_loss, entropy_loss def sample(self, B: int, device: torch.device) -> torch.Tensor: """