chore: 调整超参数与细节

Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
unanmed 2026-04-27 13:46:48 +08:00
parent cc463f543e
commit 21315a6cb0
4 changed files with 69 additions and 29 deletions

View File

@ -391,7 +391,7 @@ $$\mathcal{L} = \mathcal{L}_{CE}(\text{MaskGIT}) + \beta \cdot \mathcal{L}_{comm
- [x] 确定是否保留标量 cond(暂不使用)
- [x] 确定训练子集划分与各场景比例
- [x] 实现 VQ-VAE 编码器模块(Transformer + VQ)
- [ ] 改造 `GinkaMaskGIT.forward()` 接受 z,替换热力图分支
- [ ] 实现四种子集采样逻辑(`dataset.py` 新增多子集 Dataset 或采样权重)
- [ ] 实现子集 B/C/D 的输入构造函数(按规则清除 tile、保留墙壁/入口)
- [ ] 编写联合训练脚本,整合 VQ 损失与 MaskGIT 交叉熵损失
- [x] 改造 `GinkaMaskGIT.forward()` 接受 z,替换热力图分支
- [x] 实现四种子集采样逻辑(`dataset.py` 新增多子集 Dataset 或采样权重)
- [x] 实现子集 B/C/D 的输入构造函数(按规则清除 tile、保留墙壁/入口)
- [x] 编写联合训练脚本,整合 VQ 损失与 MaskGIT 交叉熵损失

View File

@ -364,16 +364,16 @@ class GinkaVQDataset(Dataset):
return flat
elif subset == 'B':
# 仅保留 floor(0) 和 wall(1)
# 仅保留 wall(1)floor(0) 和其他非墙内容全部 mask
flat = raw.reshape(-1).copy()
keep = (flat == self.FLOOR) | (flat == self.WALL)
keep = (flat == self.WALL)
flat[~keep] = self.MASK_ID
return flat
elif subset == 'C':
# Subset B + 随机 mask 部分 wall
flat = raw.reshape(-1).copy()
keep = (flat == self.FLOOR) | (flat == self.WALL)
keep = (flat == self.WALL)
flat[~keep] = self.MASK_ID
wall_idx = np.where(flat == self.WALL)[0]
@ -385,14 +385,18 @@ class GinkaVQDataset(Dataset):
return flat
else: # D
# 仅保留 floor(0)、wall(1) 和 entrance(10)
# 仅保留 wall(1) 和 entrance(10)floor(0) 和其他非墙内容全部 mask
flat = raw.reshape(-1).copy()
keep = (
(flat == self.FLOOR)
| (flat == self.WALL)
| (flat == self.ENTRANCE)
)
keep = (flat == self.WALL) | (flat == self.ENTRANCE)
flat[~keep] = self.MASK_ID
# 随机 mask 部分 wall模拟真实场景与子集 C 一致)
wall_idx = np.where(flat == self.WALL)[0]
if len(wall_idx) > 0:
ratio = random.random() * self.wall_mask_ratio
n = max(1, int(len(wall_idx) * ratio))
chosen = np.random.choice(wall_idx, n, replace=False)
flat[chosen] = self.MASK_ID
return flat
def __getitem__(self, idx):
@ -429,4 +433,4 @@ if __name__ == "__main__":
print(f"masked_map shape={masked.shape}, dtype={masked.dtype}")
print(f"target_map shape={target.shape}, dtype={target.dtype}")
print(f"被 mask 的位置数: {(masked == 15).sum().item()} / {masked.numel()}")
print(f"\n200 次采样子集分布: {subset_count}")
print(f"\n200 次采样子集分布: {subset_count}")

View File

@ -44,8 +44,8 @@ MAP_H = MAP_W = 13
LABEL_SMOOTHING = 0.0
# VQ-VAE 超参
VQ_L = 2 # summary token 数量(即 z 的序列长度)
VQ_K = 16 # codebook 大小
VQ_L = 64 # summary token 数量(即 z 的序列长度)
VQ_K = 1 # codebook 大小
VQ_D_Z = 64 # codebook 嵌入维度
VQ_D_MODEL= 128
VQ_NHEAD = 4
@ -180,6 +180,36 @@ def hstack_images(imgs: list, gap: int = 4, color=(255, 255, 255)) -> np.ndarray
return result
def grid_images(imgs: list, gap: int = 4, bg_color=(255, 255, 255)) -> np.ndarray:
    """Arrange images into a two-row grid for easier side-by-side viewing.

    The top row holds the first ceil(N/2) images and the bottom row the
    remaining floor(N/2); the narrower row is right-padded with the
    background colour so both rows can be stacked vertically.

    Args:
        imgs: list of HxWx3 uint8 images (heights within a row are assumed
            compatible with ``hstack_images`` — TODO confirm).
        gap: pixel gap between images and between the two rows.
        bg_color: RGB fill colour for gaps and padding.

    Returns:
        A single HxWx3 uint8 image; a 1x1 black pixel when ``imgs`` is empty,
        and the sole image itself (not a copy) when ``imgs`` has one element.
    """
    count = len(imgs)
    if count == 0:
        return np.zeros((1, 1, 3), dtype=np.uint8)
    if count == 1:
        return imgs[0]
    split = math.ceil(count / 2)
    upper = hstack_images(imgs[:split], gap, bg_color)
    lower_imgs = imgs[split:]
    if not lower_imgs:  # unreachable for count >= 2; kept as a defensive guard
        return upper
    lower = hstack_images(lower_imgs, gap, bg_color)
    # Equalise widths: pad the narrower row on its right with the background.
    width_diff = upper.shape[1] - lower.shape[1]
    if width_diff > 0:
        filler = np.full((lower.shape[0], width_diff, 3), bg_color, dtype=np.uint8)
        lower = np.concatenate([lower, filler], axis=1)
    elif width_diff < 0:
        filler = np.full((upper.shape[0], -width_diff, 3), bg_color, dtype=np.uint8)
        upper = np.concatenate([upper, filler], axis=1)
    # Horizontal background strip separating the two rows.
    spacer = np.full((gap, upper.shape[1], 3), bg_color, dtype=np.uint8)
    return np.concatenate([upper, spacer, lower], axis=0)
def label_image(img: np.ndarray, text: str, font_scale: float = 0.45) -> np.ndarray:
"""在图片顶部加一行文字标签(就地修改并返回)。"""
bar_h = 16
@ -248,7 +278,7 @@ def validate(
subsets = batch["subset"] # list of str
B = raw_map.shape[0]
z_q, _, vq_loss = model_vq(raw_map)
z_q, _, vq_loss, _, _ = model_vq(raw_map)
logits = model_mg(masked_map, z_q)
mask = (masked_map == MASK_TOKEN)
@ -298,7 +328,7 @@ def validate(
gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")
row = [real_img, cond_img, pred_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
cv2.imwrite(f"{epoch_dir}/scene1_completion.png", hstack_images(row))
cv2.imwrite(f"{epoch_dir}/scene1_completion.png", grid_images(row))
# ── 场景2墙壁辅助生成子集 B─────────────────────────────────────────
if captured['B'] is not None:
@ -311,7 +341,7 @@ def validate(
gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")
row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
cv2.imwrite(f"{epoch_dir}/scene2_wall.png", hstack_images(row))
cv2.imwrite(f"{epoch_dir}/scene2_wall.png", grid_images(row))
# ── 场景3稀疏墙壁条件生成子集 C────────────────────────────────────
if captured['C'] is not None:
@ -324,7 +354,7 @@ def validate(
gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")
row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
cv2.imwrite(f"{epoch_dir}/scene3_sparse.png", hstack_images(row))
cv2.imwrite(f"{epoch_dir}/scene3_sparse.png", grid_images(row))
# ── 场景4墙壁+入口条件生成(子集 D───────────────────────────────────
if captured['D'] is not None:
@ -337,14 +367,14 @@ def validate(
gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")
row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
cv2.imwrite(f"{epoch_dir}/scene4_entrance.png", hstack_images(row))
cv2.imwrite(f"{epoch_dir}/scene4_entrance.png", grid_images(row))
# ── 场景5完全随机生成无数据集参照──────────────────────────────────
# 随机稀疏墙壁种子,对应推理时"不提供任何条件,直接生成"的场景
rand_seed = make_random_wall_seed() # [1, MAP_SIZE]
seed_img = label_image(make_map_image(rand_seed[0], tile_dict), "random seed")
row = [seed_img] + _rand_gens(rand_seed, N_Z_SAMPLES + 1) # 多采一个 z 展示多样性
cv2.imwrite(f"{epoch_dir}/scene5_random.png", hstack_images(row))
cv2.imwrite(f"{epoch_dir}/scene5_random.png", grid_images(row))
avg_val_loss = val_loss_total / max(val_steps, 1)
return avg_val_loss
@ -431,10 +461,12 @@ def train():
model_vq.train()
model_mg.train()
loss_total = 0.0
ce_total = 0.0
vq_loss_total = 0.0
subset_stats = {'A': 0, 'B': 0, 'C': 0, 'D': 0}
loss_total = 0.0
ce_total = 0.0
vq_loss_total = 0.0
commit_total = 0.0
entropy_total = 0.0
subset_stats = {'A': 0, 'B': 0, 'C': 0, 'D': 0}
for batch in tqdm(dataloader_train, leave=False,
desc="Epoch Progress", disable=disable_tqdm):
@ -447,7 +479,7 @@ def train():
# ---- 前向传播 ----
# 1. VQ-VAE 编码真实地图 → z_q
z_q, _, vq_loss = model_vq(raw_map) # z_q: [B, L, d_z]
z_q, _, vq_loss, commit_loss, entropy_loss = model_vq(raw_map) # z_q: [B, L, d_z]
# 2. MaskGIT 以掩码地图 + z 预测原始 tile
logits = model_mg(masked_map, z_q) # [B, 169, C]
@ -471,6 +503,8 @@ def train():
loss_total += loss.detach().item()
ce_total += masked_ce.detach().item()
vq_loss_total += vq_loss.detach().item()
commit_total += commit_loss.detach().item()
entropy_total += entropy_loss.detach().item()
scheduler.step()
@ -480,7 +514,9 @@ def train():
f"Epoch {epoch + 1:4d} | "
f"Loss {loss_total/n:.5f} "
f"CE {ce_total/n:.5f} "
f"VQ {vq_loss_total/n:.5f} | "
f"VQ {vq_loss_total/n:.5f} "
f"Commit {commit_total/n:.5f} "
f"Entropy {entropy_total/n:.5f} | "
f"LR {scheduler.get_last_lr()[0]:.6f} | "
f"Subsets {subset_stats}"
)

View File

@ -131,7 +131,7 @@ class GinkaVQVAE(nn.Module):
z_q, indices, commit_loss, entropy_loss = self.vq(z_e)
vq_loss = self.beta * commit_loss + self.gamma * entropy_loss
return z_q, indices, vq_loss
return z_q, indices, vq_loss, commit_loss, entropy_loss
def sample(self, B: int, device: torch.device) -> torch.Tensor:
"""