mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-17 23:21:20 +08:00
chore: 调整超参数与细节
Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
parent
cc463f543e
commit
21315a6cb0
@ -391,7 +391,7 @@ $$\mathcal{L} = \mathcal{L}_{CE}(\text{MaskGIT}) + \beta \cdot \mathcal{L}_{comm
|
||||
- [x] 确定是否保留标量 cond(暂不使用)
|
||||
- [x] 确定训练子集划分与各场景比例
|
||||
- [x] 实现 VQ-VAE 编码器模块(Transformer + VQ)
|
||||
- [ ] 改造 `GinkaMaskGIT.forward()` 接受 z,替换热力图分支
|
||||
- [ ] 实现四种子集采样逻辑(`dataset.py` 新增多子集 Dataset 或采样权重)
|
||||
- [ ] 实现子集 B/C/D 的输入构造函数(按规则清除 tile、保留墙壁/入口)
|
||||
- [ ] 编写联合训练脚本,整合 VQ 损失与 MaskGIT 交叉熵损失
|
||||
- [x] 改造 `GinkaMaskGIT.forward()` 接受 z,替换热力图分支
|
||||
- [x] 实现四种子集采样逻辑(`dataset.py` 新增多子集 Dataset 或采样权重)
|
||||
- [x] 实现子集 B/C/D 的输入构造函数(按规则清除 tile、保留墙壁/入口)
|
||||
- [x] 编写联合训练脚本,整合 VQ 损失与 MaskGIT 交叉熵损失
|
||||
|
||||
@ -364,16 +364,16 @@ class GinkaVQDataset(Dataset):
|
||||
return flat
|
||||
|
||||
elif subset == 'B':
|
||||
# 仅保留 floor(0) 和 wall(1)
|
||||
# 仅保留 wall(1),floor(0) 和其他非墙内容全部 mask
|
||||
flat = raw.reshape(-1).copy()
|
||||
keep = (flat == self.FLOOR) | (flat == self.WALL)
|
||||
keep = (flat == self.WALL)
|
||||
flat[~keep] = self.MASK_ID
|
||||
return flat
|
||||
|
||||
elif subset == 'C':
|
||||
# Subset B + 随机 mask 部分 wall
|
||||
flat = raw.reshape(-1).copy()
|
||||
keep = (flat == self.FLOOR) | (flat == self.WALL)
|
||||
keep = (flat == self.WALL)
|
||||
flat[~keep] = self.MASK_ID
|
||||
|
||||
wall_idx = np.where(flat == self.WALL)[0]
|
||||
@ -385,14 +385,18 @@ class GinkaVQDataset(Dataset):
|
||||
return flat
|
||||
|
||||
else: # D
|
||||
# 仅保留 floor(0)、wall(1) 和 entrance(10)
|
||||
# 仅保留 wall(1) 和 entrance(10),floor(0) 和其他非墙内容全部 mask
|
||||
flat = raw.reshape(-1).copy()
|
||||
keep = (
|
||||
(flat == self.FLOOR)
|
||||
| (flat == self.WALL)
|
||||
| (flat == self.ENTRANCE)
|
||||
)
|
||||
keep = (flat == self.WALL) | (flat == self.ENTRANCE)
|
||||
flat[~keep] = self.MASK_ID
|
||||
|
||||
# 随机 mask 部分 wall(模拟真实场景,与子集 C 一致)
|
||||
wall_idx = np.where(flat == self.WALL)[0]
|
||||
if len(wall_idx) > 0:
|
||||
ratio = random.random() * self.wall_mask_ratio
|
||||
n = max(1, int(len(wall_idx) * ratio))
|
||||
chosen = np.random.choice(wall_idx, n, replace=False)
|
||||
flat[chosen] = self.MASK_ID
|
||||
return flat
|
||||
|
||||
def __getitem__(self, idx):
|
||||
@ -429,4 +433,4 @@ if __name__ == "__main__":
|
||||
print(f"masked_map shape={masked.shape}, dtype={masked.dtype}")
|
||||
print(f"target_map shape={target.shape}, dtype={target.dtype}")
|
||||
print(f"被 mask 的位置数: {(masked == 15).sum().item()} / {masked.numel()}")
|
||||
print(f"\n200 次采样子集分布: {subset_count}")
|
||||
print(f"\n200 次采样子集分布: {subset_count}")
|
||||
|
||||
@ -44,8 +44,8 @@ MAP_H = MAP_W = 13
|
||||
LABEL_SMOOTHING = 0.0
|
||||
|
||||
# VQ-VAE 超参
|
||||
VQ_L = 2 # summary token 数量(即 z 的序列长度)
|
||||
VQ_K = 16 # codebook 大小
|
||||
VQ_L = 64 # summary token 数量(即 z 的序列长度)
|
||||
VQ_K = 1 # codebook 大小
|
||||
VQ_D_Z = 64 # codebook 嵌入维度
|
||||
VQ_D_MODEL= 128
|
||||
VQ_NHEAD = 4
|
||||
@ -180,6 +180,36 @@ def hstack_images(imgs: list, gap: int = 4, color=(255, 255, 255)) -> np.ndarray
|
||||
return result
|
||||
|
||||
|
||||
def grid_images(imgs: list, gap: int = 4, bg_color=(255, 255, 255)) -> np.ndarray:
|
||||
"""将图片列表排成两行网格(上行 ceil(N/2),下行 floor(N/2)),方便查看。"""
|
||||
n = len(imgs)
|
||||
if n == 0:
|
||||
return np.zeros((1, 1, 3), dtype=np.uint8)
|
||||
if n == 1:
|
||||
return imgs[0]
|
||||
|
||||
mid = math.ceil(n / 2)
|
||||
top_row = hstack_images(imgs[:mid], gap, bg_color)
|
||||
bot_imgs = imgs[mid:]
|
||||
|
||||
if not bot_imgs:
|
||||
return top_row
|
||||
|
||||
bot_row = hstack_images(bot_imgs, gap, bg_color)
|
||||
|
||||
# 补齐宽度(右侧填充背景色)
|
||||
tw, bw = top_row.shape[1], bot_row.shape[1]
|
||||
if tw > bw:
|
||||
pad = np.full((bot_row.shape[0], tw - bw, 3), bg_color, dtype=np.uint8)
|
||||
bot_row = np.concatenate([bot_row, pad], axis=1)
|
||||
elif bw > tw:
|
||||
pad = np.full((top_row.shape[0], bw - tw, 3), bg_color, dtype=np.uint8)
|
||||
top_row = np.concatenate([top_row, pad], axis=1)
|
||||
|
||||
hline = np.full((gap, top_row.shape[1], 3), bg_color, dtype=np.uint8)
|
||||
return np.concatenate([top_row, hline, bot_row], axis=0)
|
||||
|
||||
|
||||
def label_image(img: np.ndarray, text: str, font_scale: float = 0.45) -> np.ndarray:
|
||||
"""在图片顶部加一行文字标签(就地修改并返回)。"""
|
||||
bar_h = 16
|
||||
@ -248,7 +278,7 @@ def validate(
|
||||
subsets = batch["subset"] # list of str
|
||||
B = raw_map.shape[0]
|
||||
|
||||
z_q, _, vq_loss = model_vq(raw_map)
|
||||
z_q, _, vq_loss, _, _ = model_vq(raw_map)
|
||||
logits = model_mg(masked_map, z_q)
|
||||
mask = (masked_map == MASK_TOKEN)
|
||||
|
||||
@ -298,7 +328,7 @@ def validate(
|
||||
gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")
|
||||
|
||||
row = [real_img, cond_img, pred_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
|
||||
cv2.imwrite(f"{epoch_dir}/scene1_completion.png", hstack_images(row))
|
||||
cv2.imwrite(f"{epoch_dir}/scene1_completion.png", grid_images(row))
|
||||
|
||||
# ── 场景2:墙壁辅助生成(子集 B)─────────────────────────────────────────
|
||||
if captured['B'] is not None:
|
||||
@ -311,7 +341,7 @@ def validate(
|
||||
gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")
|
||||
|
||||
row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
|
||||
cv2.imwrite(f"{epoch_dir}/scene2_wall.png", hstack_images(row))
|
||||
cv2.imwrite(f"{epoch_dir}/scene2_wall.png", grid_images(row))
|
||||
|
||||
# ── 场景3:稀疏墙壁条件生成(子集 C)────────────────────────────────────
|
||||
if captured['C'] is not None:
|
||||
@ -324,7 +354,7 @@ def validate(
|
||||
gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")
|
||||
|
||||
row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
|
||||
cv2.imwrite(f"{epoch_dir}/scene3_sparse.png", hstack_images(row))
|
||||
cv2.imwrite(f"{epoch_dir}/scene3_sparse.png", grid_images(row))
|
||||
|
||||
# ── 场景4:墙壁+入口条件生成(子集 D)───────────────────────────────────
|
||||
if captured['D'] is not None:
|
||||
@ -337,14 +367,14 @@ def validate(
|
||||
gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")
|
||||
|
||||
row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
|
||||
cv2.imwrite(f"{epoch_dir}/scene4_entrance.png", hstack_images(row))
|
||||
cv2.imwrite(f"{epoch_dir}/scene4_entrance.png", grid_images(row))
|
||||
|
||||
# ── 场景5:完全随机生成(无数据集参照)──────────────────────────────────
|
||||
# 随机稀疏墙壁种子,对应推理时"不提供任何条件,直接生成"的场景
|
||||
rand_seed = make_random_wall_seed() # [1, MAP_SIZE]
|
||||
seed_img = label_image(make_map_image(rand_seed[0], tile_dict), "random seed")
|
||||
row = [seed_img] + _rand_gens(rand_seed, N_Z_SAMPLES + 1) # 多采一个 z 展示多样性
|
||||
cv2.imwrite(f"{epoch_dir}/scene5_random.png", hstack_images(row))
|
||||
cv2.imwrite(f"{epoch_dir}/scene5_random.png", grid_images(row))
|
||||
|
||||
avg_val_loss = val_loss_total / max(val_steps, 1)
|
||||
return avg_val_loss
|
||||
@ -431,10 +461,12 @@ def train():
|
||||
model_vq.train()
|
||||
model_mg.train()
|
||||
|
||||
loss_total = 0.0
|
||||
ce_total = 0.0
|
||||
vq_loss_total = 0.0
|
||||
subset_stats = {'A': 0, 'B': 0, 'C': 0, 'D': 0}
|
||||
loss_total = 0.0
|
||||
ce_total = 0.0
|
||||
vq_loss_total = 0.0
|
||||
commit_total = 0.0
|
||||
entropy_total = 0.0
|
||||
subset_stats = {'A': 0, 'B': 0, 'C': 0, 'D': 0}
|
||||
|
||||
for batch in tqdm(dataloader_train, leave=False,
|
||||
desc="Epoch Progress", disable=disable_tqdm):
|
||||
@ -447,7 +479,7 @@ def train():
|
||||
|
||||
# ---- 前向传播 ----
|
||||
# 1. VQ-VAE 编码真实地图 → z_q
|
||||
z_q, _, vq_loss = model_vq(raw_map) # z_q: [B, L, d_z]
|
||||
z_q, _, vq_loss, commit_loss, entropy_loss = model_vq(raw_map) # z_q: [B, L, d_z]
|
||||
|
||||
# 2. MaskGIT 以掩码地图 + z 预测原始 tile
|
||||
logits = model_mg(masked_map, z_q) # [B, 169, C]
|
||||
@ -471,6 +503,8 @@ def train():
|
||||
loss_total += loss.detach().item()
|
||||
ce_total += masked_ce.detach().item()
|
||||
vq_loss_total += vq_loss.detach().item()
|
||||
commit_total += commit_loss.detach().item()
|
||||
entropy_total += entropy_loss.detach().item()
|
||||
|
||||
scheduler.step()
|
||||
|
||||
@ -480,7 +514,9 @@ def train():
|
||||
f"Epoch {epoch + 1:4d} | "
|
||||
f"Loss {loss_total/n:.5f} "
|
||||
f"CE {ce_total/n:.5f} "
|
||||
f"VQ {vq_loss_total/n:.5f} | "
|
||||
f"VQ {vq_loss_total/n:.5f} "
|
||||
f"Commit {commit_total/n:.5f} "
|
||||
f"Entropy {entropy_total/n:.5f} | "
|
||||
f"LR {scheduler.get_last_lr()[0]:.6f} | "
|
||||
f"Subsets {subset_stats}"
|
||||
)
|
||||
|
||||
@ -131,7 +131,7 @@ class GinkaVQVAE(nn.Module):
|
||||
z_q, indices, commit_loss, entropy_loss = self.vq(z_e)
|
||||
|
||||
vq_loss = self.beta * commit_loss + self.gamma * entropy_loss
|
||||
return z_q, indices, vq_loss
|
||||
return z_q, indices, vq_loss, commit_loss, entropy_loss
|
||||
|
||||
def sample(self, B: int, device: torch.device) -> torch.Tensor:
|
||||
"""
|
||||
|
||||
Loading…
Reference in New Issue
Block a user