chore: 调整超参数与细节

Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
unanmed 2026-04-27 13:46:48 +08:00
parent cc463f543e
commit 21315a6cb0
4 changed files with 69 additions and 29 deletions

View File

@ -391,7 +391,7 @@ $$\mathcal{L} = \mathcal{L}_{CE}(\text{MaskGIT}) + \beta \cdot \mathcal{L}_{comm
- [x] 确定是否保留标量 cond暂不使用 - [x] 确定是否保留标量 cond暂不使用
- [x] 确定训练子集划分与各场景比例 - [x] 确定训练子集划分与各场景比例
- [x] 实现 VQ-VAE 编码器模块Transformer + VQ - [x] 实现 VQ-VAE 编码器模块Transformer + VQ
- [ ] 改造 `GinkaMaskGIT.forward()` 接受 z替换热力图分支 - [x] 改造 `GinkaMaskGIT.forward()` 接受 z替换热力图分支
- [ ] 实现四种子集采样逻辑(`dataset.py` 新增多子集 Dataset 或采样权重) - [x] 实现四种子集采样逻辑(`dataset.py` 新增多子集 Dataset 或采样权重)
- [ ] 实现子集 B/C/D 的输入构造函数(按规则清除 tile、保留墙壁/入口) - [x] 实现子集 B/C/D 的输入构造函数(按规则清除 tile、保留墙壁/入口)
- [ ] 编写联合训练脚本,整合 VQ 损失与 MaskGIT 交叉熵损失 - [x] 编写联合训练脚本,整合 VQ 损失与 MaskGIT 交叉熵损失

View File

@ -364,16 +364,16 @@ class GinkaVQDataset(Dataset):
return flat return flat
elif subset == 'B': elif subset == 'B':
# 仅保留 floor(0) 和 wall(1) # 仅保留 wall(1)floor(0) 和其他非墙内容全部 mask
flat = raw.reshape(-1).copy() flat = raw.reshape(-1).copy()
keep = (flat == self.FLOOR) | (flat == self.WALL) keep = (flat == self.WALL)
flat[~keep] = self.MASK_ID flat[~keep] = self.MASK_ID
return flat return flat
elif subset == 'C': elif subset == 'C':
# Subset B + 随机 mask 部分 wall # Subset B + 随机 mask 部分 wall
flat = raw.reshape(-1).copy() flat = raw.reshape(-1).copy()
keep = (flat == self.FLOOR) | (flat == self.WALL) keep = (flat == self.WALL)
flat[~keep] = self.MASK_ID flat[~keep] = self.MASK_ID
wall_idx = np.where(flat == self.WALL)[0] wall_idx = np.where(flat == self.WALL)[0]
@ -385,14 +385,18 @@ class GinkaVQDataset(Dataset):
return flat return flat
else: # D else: # D
# 仅保留 floor(0)、wall(1) 和 entrance(10) # 仅保留 wall(1) 和 entrance(10)floor(0) 和其他非墙内容全部 mask
flat = raw.reshape(-1).copy() flat = raw.reshape(-1).copy()
keep = ( keep = (flat == self.WALL) | (flat == self.ENTRANCE)
(flat == self.FLOOR)
| (flat == self.WALL)
| (flat == self.ENTRANCE)
)
flat[~keep] = self.MASK_ID flat[~keep] = self.MASK_ID
# 随机 mask 部分 wall模拟真实场景与子集 C 一致)
wall_idx = np.where(flat == self.WALL)[0]
if len(wall_idx) > 0:
ratio = random.random() * self.wall_mask_ratio
n = max(1, int(len(wall_idx) * ratio))
chosen = np.random.choice(wall_idx, n, replace=False)
flat[chosen] = self.MASK_ID
return flat return flat
def __getitem__(self, idx): def __getitem__(self, idx):
@ -429,4 +433,4 @@ if __name__ == "__main__":
print(f"masked_map shape={masked.shape}, dtype={masked.dtype}") print(f"masked_map shape={masked.shape}, dtype={masked.dtype}")
print(f"target_map shape={target.shape}, dtype={target.dtype}") print(f"target_map shape={target.shape}, dtype={target.dtype}")
print(f"被 mask 的位置数: {(masked == 15).sum().item()} / {masked.numel()}") print(f"被 mask 的位置数: {(masked == 15).sum().item()} / {masked.numel()}")
print(f"\n200 次采样子集分布: {subset_count}") print(f"\n200 次采样子集分布: {subset_count}")

View File

@ -44,8 +44,8 @@ MAP_H = MAP_W = 13
LABEL_SMOOTHING = 0.0 LABEL_SMOOTHING = 0.0
# VQ-VAE 超参 # VQ-VAE 超参
VQ_L = 2 # summary token 数量(即 z 的序列长度) VQ_L = 64 # summary token 数量(即 z 的序列长度)
VQ_K = 16 # codebook 大小 VQ_K = 1 # codebook 大小
VQ_D_Z = 64 # codebook 嵌入维度 VQ_D_Z = 64 # codebook 嵌入维度
VQ_D_MODEL= 128 VQ_D_MODEL= 128
VQ_NHEAD = 4 VQ_NHEAD = 4
@ -180,6 +180,36 @@ def hstack_images(imgs: list, gap: int = 4, color=(255, 255, 255)) -> np.ndarray
return result return result
def grid_images(imgs: list, gap: int = 4, bg_color=(255, 255, 255)) -> np.ndarray:
"""将图片列表排成两行网格(上行 ceil(N/2),下行 floor(N/2)),方便查看。"""
n = len(imgs)
if n == 0:
return np.zeros((1, 1, 3), dtype=np.uint8)
if n == 1:
return imgs[0]
mid = math.ceil(n / 2)
top_row = hstack_images(imgs[:mid], gap, bg_color)
bot_imgs = imgs[mid:]
if not bot_imgs:
return top_row
bot_row = hstack_images(bot_imgs, gap, bg_color)
# 补齐宽度(右侧填充背景色)
tw, bw = top_row.shape[1], bot_row.shape[1]
if tw > bw:
pad = np.full((bot_row.shape[0], tw - bw, 3), bg_color, dtype=np.uint8)
bot_row = np.concatenate([bot_row, pad], axis=1)
elif bw > tw:
pad = np.full((top_row.shape[0], bw - tw, 3), bg_color, dtype=np.uint8)
top_row = np.concatenate([top_row, pad], axis=1)
hline = np.full((gap, top_row.shape[1], 3), bg_color, dtype=np.uint8)
return np.concatenate([top_row, hline, bot_row], axis=0)
def label_image(img: np.ndarray, text: str, font_scale: float = 0.45) -> np.ndarray: def label_image(img: np.ndarray, text: str, font_scale: float = 0.45) -> np.ndarray:
"""在图片顶部加一行文字标签(就地修改并返回)。""" """在图片顶部加一行文字标签(就地修改并返回)。"""
bar_h = 16 bar_h = 16
@ -248,7 +278,7 @@ def validate(
subsets = batch["subset"] # list of str subsets = batch["subset"] # list of str
B = raw_map.shape[0] B = raw_map.shape[0]
z_q, _, vq_loss = model_vq(raw_map) z_q, _, vq_loss, _, _ = model_vq(raw_map)
logits = model_mg(masked_map, z_q) logits = model_mg(masked_map, z_q)
mask = (masked_map == MASK_TOKEN) mask = (masked_map == MASK_TOKEN)
@ -298,7 +328,7 @@ def validate(
gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen") gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")
row = [real_img, cond_img, pred_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES) row = [real_img, cond_img, pred_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
cv2.imwrite(f"{epoch_dir}/scene1_completion.png", hstack_images(row)) cv2.imwrite(f"{epoch_dir}/scene1_completion.png", grid_images(row))
# ── 场景2墙壁辅助生成子集 B───────────────────────────────────────── # ── 场景2墙壁辅助生成子集 B─────────────────────────────────────────
if captured['B'] is not None: if captured['B'] is not None:
@ -311,7 +341,7 @@ def validate(
gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen") gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")
row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES) row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
cv2.imwrite(f"{epoch_dir}/scene2_wall.png", hstack_images(row)) cv2.imwrite(f"{epoch_dir}/scene2_wall.png", grid_images(row))
# ── 场景3稀疏墙壁条件生成子集 C──────────────────────────────────── # ── 场景3稀疏墙壁条件生成子集 C────────────────────────────────────
if captured['C'] is not None: if captured['C'] is not None:
@ -324,7 +354,7 @@ def validate(
gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen") gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")
row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES) row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
cv2.imwrite(f"{epoch_dir}/scene3_sparse.png", hstack_images(row)) cv2.imwrite(f"{epoch_dir}/scene3_sparse.png", grid_images(row))
# ── 场景4墙壁+入口条件生成(子集 D─────────────────────────────────── # ── 场景4墙壁+入口条件生成(子集 D───────────────────────────────────
if captured['D'] is not None: if captured['D'] is not None:
@ -337,14 +367,14 @@ def validate(
gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen") gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")
row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES) row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
cv2.imwrite(f"{epoch_dir}/scene4_entrance.png", hstack_images(row)) cv2.imwrite(f"{epoch_dir}/scene4_entrance.png", grid_images(row))
# ── 场景5完全随机生成无数据集参照────────────────────────────────── # ── 场景5完全随机生成无数据集参照──────────────────────────────────
# 随机稀疏墙壁种子,对应推理时"不提供任何条件,直接生成"的场景 # 随机稀疏墙壁种子,对应推理时"不提供任何条件,直接生成"的场景
rand_seed = make_random_wall_seed() # [1, MAP_SIZE] rand_seed = make_random_wall_seed() # [1, MAP_SIZE]
seed_img = label_image(make_map_image(rand_seed[0], tile_dict), "random seed") seed_img = label_image(make_map_image(rand_seed[0], tile_dict), "random seed")
row = [seed_img] + _rand_gens(rand_seed, N_Z_SAMPLES + 1) # 多采一个 z 展示多样性 row = [seed_img] + _rand_gens(rand_seed, N_Z_SAMPLES + 1) # 多采一个 z 展示多样性
cv2.imwrite(f"{epoch_dir}/scene5_random.png", hstack_images(row)) cv2.imwrite(f"{epoch_dir}/scene5_random.png", grid_images(row))
avg_val_loss = val_loss_total / max(val_steps, 1) avg_val_loss = val_loss_total / max(val_steps, 1)
return avg_val_loss return avg_val_loss
@ -431,10 +461,12 @@ def train():
model_vq.train() model_vq.train()
model_mg.train() model_mg.train()
loss_total = 0.0 loss_total = 0.0
ce_total = 0.0 ce_total = 0.0
vq_loss_total = 0.0 vq_loss_total = 0.0
subset_stats = {'A': 0, 'B': 0, 'C': 0, 'D': 0} commit_total = 0.0
entropy_total = 0.0
subset_stats = {'A': 0, 'B': 0, 'C': 0, 'D': 0}
for batch in tqdm(dataloader_train, leave=False, for batch in tqdm(dataloader_train, leave=False,
desc="Epoch Progress", disable=disable_tqdm): desc="Epoch Progress", disable=disable_tqdm):
@ -447,7 +479,7 @@ def train():
# ---- 前向传播 ---- # ---- 前向传播 ----
# 1. VQ-VAE 编码真实地图 → z_q # 1. VQ-VAE 编码真实地图 → z_q
z_q, _, vq_loss = model_vq(raw_map) # z_q: [B, L, d_z] z_q, _, vq_loss, commit_loss, entropy_loss = model_vq(raw_map) # z_q: [B, L, d_z]
# 2. MaskGIT 以掩码地图 + z 预测原始 tile # 2. MaskGIT 以掩码地图 + z 预测原始 tile
logits = model_mg(masked_map, z_q) # [B, 169, C] logits = model_mg(masked_map, z_q) # [B, 169, C]
@ -471,6 +503,8 @@ def train():
loss_total += loss.detach().item() loss_total += loss.detach().item()
ce_total += masked_ce.detach().item() ce_total += masked_ce.detach().item()
vq_loss_total += vq_loss.detach().item() vq_loss_total += vq_loss.detach().item()
commit_total += commit_loss.detach().item()
entropy_total += entropy_loss.detach().item()
scheduler.step() scheduler.step()
@ -480,7 +514,9 @@ def train():
f"Epoch {epoch + 1:4d} | " f"Epoch {epoch + 1:4d} | "
f"Loss {loss_total/n:.5f} " f"Loss {loss_total/n:.5f} "
f"CE {ce_total/n:.5f} " f"CE {ce_total/n:.5f} "
f"VQ {vq_loss_total/n:.5f} | " f"VQ {vq_loss_total/n:.5f} "
f"Commit {commit_total/n:.5f} "
f"Entropy {entropy_total/n:.5f} | "
f"LR {scheduler.get_last_lr()[0]:.6f} | " f"LR {scheduler.get_last_lr()[0]:.6f} | "
f"Subsets {subset_stats}" f"Subsets {subset_stats}"
) )

View File

@ -131,7 +131,7 @@ class GinkaVQVAE(nn.Module):
z_q, indices, commit_loss, entropy_loss = self.vq(z_e) z_q, indices, commit_loss, entropy_loss = self.vq(z_e)
vq_loss = self.beta * commit_loss + self.gamma * entropy_loss vq_loss = self.beta * commit_loss + self.gamma * entropy_loss
return z_q, indices, vq_loss return z_q, indices, vq_loss, commit_loss, entropy_loss
def sample(self, B: int, device: torch.device) -> torch.Tensor: def sample(self, B: int, device: torch.device) -> torch.Tensor:
""" """