diff --git a/docs/z-improvement-design.md b/docs/z-improvement-design.md new file mode 100644 index 0000000..de7c366 --- /dev/null +++ b/docs/z-improvement-design.md @@ -0,0 +1,212 @@ +# z 控制力改进方案设计文档 + +## 背景 + +当前 VQ-VAE + MaskGIT 联合训练方案中,VQ-VAE 编码器将完整地图压缩为离散隐变量 z,z 经 Cross-Attention 注入 MaskGIT。实践中发现 z 对生成结果的控制力仍然不足——即使给定相同的 z,生成结果与原地图的结构相似度也偏低。 + +本文档整理以下两个可行改进方向,并详细分析各自的技术路径与取舍: + +| 方案 | 核心思路 | 状态 | +| ------ | -------------------------------------------------- | -------- | +| 方案 A | 重建一致性约束:将生成结果回送编码器,令 z 闭环 | 待细化 | +| 方案 B | 多路分拆编码:将地图按层次结构分拆为多部分分别编码 | 待细化 | +| 方案 C | 多阶段生成:先墙壁,再门怪,最后资源 | 后续计划 | + +--- + +## 方案 A:重建一致性约束(z 闭环) + +### 核心思路 + +在联合训练过程中,对 MaskGIT 生成的地图再次走一遍编码器,得到 z_pred,然后约束: + +$$\mathcal{L}_{consist} = \| z - z_{pred} \|^2$$ + +如果 z_pred 能与原始 z 对齐,说明模型学到了"z 控制 → 生成结果 → 编码 → 同一个 z"的完整回路,从而强化 z 对生成内容的控制力。 + +### 核心难点:离散地图的梯度传递 + +MaskGIT 输出的是每个位置的分类 logits,最终生成的地图是通过 argmax(或 Gumbel 采样)得到的**离散整数矩阵**。离散采样不可微,梯度无法直接通过 argmax 回传到 MaskGIT。 + +以下是三种可行的梯度传递方案: + +#### 方案 A-1:软分布近似(Soft Logits 送入编码器) + +跳过 argmax,将 MaskGIT 输出的原始 logits(或 softmax 概率分布)直接送入编码器,替代离散 token 输入。 + +``` +MaskGIT logits [B, H*W, V] + │ + ▼ +将每个位置的 softmax(logits) 与 tile embedding 加权求和 + → 得到 [B, H*W, d_emb] 的软嵌入序列 + │ + ▼ +VQ-VAE 编码器(接收软嵌入输入)→ z_pred + │ + ▼ +L_consist = ||z - z_pred||² + │ + ▼ +梯度正常反传到 MaskGIT(无离散断层) +``` + +**优点**:完全可微,无需额外技巧。 +**缺点**:编码器在训练时接收的是连续软嵌入,与推理时接收离散地图存在 train/inference gap;编码器需要支持两种输入形式,增加复杂度。 + +#### 方案 A-2:Straight-Through Estimator(STE) + +利用 STE 在前向使用 argmax,在后向将梯度"直通"绕过 argmax 传给 logits: + +$$\text{forward: } \hat{y} = \text{onehot}(\arg\max(\text{logits}))$$ +$$\text{backward: } \frac{\partial \mathcal{L}}{\partial \text{logits}} \approx \frac{\partial \mathcal{L}}{\partial \text{softmax(logits)}}$$ + +即在前向正常采样离散地图,在后向把 z_pred 的梯度传回 softmax(logits),再流入 MaskGIT。 + +**优点**:编码器输入与推理时一致(离散地图),无 train/inference gap;VQ-VAE 自身已使用 STE,风格统一。 +**缺点**:STE 是梯度近似,在离散度很高(logits 尖锐)时偏差较大;效果依赖于 logits 的平滑程度。 + +#### 方案 A-3:REINFORCE / 策略梯度(作为备选) + +将 z_consist 损失视为强化学习奖励: + +$$\mathcal{L}_{RL} = -\mathbb{E}_{\hat{y} \sim \pi_\theta}[r(\hat{y})]$$ + +其中 $r(\hat{y}) = -\| z - \text{Enc}(\hat{y}) \|^2$(负的一致性误差作为奖励),$\pi_\theta$ 为 MaskGIT 的采样策略。 + +**优点**:不依赖梯度近似,理论上完全正确。 +**缺点**:方差高,训练不稳定,需要引入 baseline 估计;地图较大时(169 个位置的联合分布)方差尤为严重。**不推荐作为主方案。** + +### 推荐路径 + +**优先尝试方案 A-1(软分布近似)**,但输入编码器的不应是原始 softmax 概率——logits 的分布与真实离散地图差距较大,编码器难以理解。具体做法:对 logits 先施加温度 smoothing(降低分布的尖锐程度),再将平滑后的 softmax 概率与 tile embedding 矩阵做加权和,得到连续软嵌入序列后送入编码器。 + +若实验中发现训练与推理的 gap 仍然明显影响效果,再切换到方案 A-2(STE)。 + +### 损失权重 + +$\mathcal{L}_{consist}$ 作为辅助损失,建议权重 $\lambda_{consist}$ 初始设为 0.1,观察 z 的编码熵(codebook 使用分布)变化后再调整: + +$$\mathcal{L}_{total} = \mathcal{L}_{CE} + \beta \cdot \mathcal{L}_{commit} + \gamma \cdot \mathcal{L}_{uniform} + \lambda \cdot \mathcal{L}_{consist}$$ + +### 训练细节 + +- 一致性约束在**所有子集(A/B/C/D)**中均应用。虽然子集 B/C/D 的输入地图不完整,但模型同样需要参考 z 来生成,对所有子集施加一致性约束有助于提高模型在指定条件下的生成能力。 +- 训练时 MaskGIT 只进行单步解码,一致性约束在该单步解码结果上计算,无需多步展开,计算开销可控。 + +--- + +## 方案 B:多路分拆编码(Hierarchical z) + +### 核心思路 + +将完整地图按语义层次拆分为多个子地图,分别编码为较短的 z 向量,拼接后注入 MaskGIT。不同的 z 通道承载不同层次的语义信息,使 MaskGIT 能够从多个粒度获得控制信号。 + +### 分拆方案 + +| 通道 | 保留的图块类型 | 语义含义 | +| ------ | --------------------------------------------- | ------------------ | +| 通道 1 | 墙壁(wall)+ 楼梯/传送点(stair) | 地图骨架与层间连接 | +| 通道 2 | 墙壁 + 入口(entry)+ 怪物(mob)+ 门(door) | 关卡结构与交互元素 | +| 通道 3 | 完整地图(所有图块) | 全局风格与密度 | + +未保留的图块位置填为"空白"(tile = 0)后送入对应编码器。 + +### 架构设计 + +``` +完整地图 [H, W] + │ + ├──► 提取通道 1(墙 + 楼梯)──► Encoder_1 ──► z_1 [L_1, d_z] + │ + ├──► 提取通道 2(墙 + 入口 + 怪 + 门)──► Encoder_2 ──► z_2 [L_2, d_z] + │ + └──► 原始地图 ──► Encoder_3 ──► z_3 [L_3, d_z] + +z = Concat([z_1, z_2, z_3], dim=1) # [L_1 + L_2 + L_3, d_z] + │ + ▼ +MaskGIT Cross-Attention(z 作为 memory) +``` + +### 码字数量分配 + +三个通道的信息量递增,码字数量(z 序列长度)应与信息量成比例。以总码字数 L 为预算: + +| 通道 | 建议码字比例 | 说明 | +| ---- | ------------ | -------------------------------------- | +| z_1 | L/4 | 骨架信息精简,少量码字即可捕捉 | +| z_2 | L/4 | 关卡结构比骨架复杂,但仍是局部稀疏信息 | +| z_3 | L/2 | 全图风格需要更多表达维度 | + +若当前方案中 L = 4(即每通道 1 个码字),则分配为 1 + 1 + 2 = 4,与原方案参数量相同。考虑到数据集规模较小(约 10k 条),总码字数 L 不宜设置过大,避免 codebook 因样本不足而难以充分训练。 + +### 是否共享编码器权重 + +| 策略 | 优点 | 缺点 | +| ---------------------------- | ---------------------- | ------------------------------------ | +| 三路独立编码器 | 每个通道学习专属的表示 | 参数量增加约 3x | +| 共享底层权重(Siamese 风格) | 参数高效,底层特征共用 | 需要精心设计通道标识符区分三路输入 | +| 完全共享(同一个编码器) | 参数最少 | 三路输入差异大,同一编码器可能欠拟合 | + +**推荐**:**共享主干 Transformer + 三路独立输入头与输出头**。即三个通道共享所有 Transformer 层的权重,但各自拥有独立的输入嵌入层(input head,将图块 id 映射为特征向量)和独立的输出投影层(output head,量化前的线性投影)。这样在参数量基本不变的情况下,让不同通道在输入特征空间和量化前的表示上均有分化空间,无需额外引入通道标识符 token。 + +### 推理时的使用方式 + +推理时三个通道均可独立随机采样: + +- **完全随机生成**:三路 z 均从 codebook 随机采样; +- **骨架条件生成**:通道 1 的 z 由用户手绘的墙壁地图编码得到,通道 2、3 随机采样; +- **精确条件生成**:通道 3 的 z 由参考地图编码得到,通道 1、2 随机采样(风格迁移场景)。 + +这一机制使三个通道的解耦具有实际的用户交互价值,而不仅仅是训练侧的正则化手段。 + +### 与方案 A 的兼容性 + +方案 B 与方案 A 不冲突。在方案 B 的架构下,重建一致性约束可以单独作用于**通道 3(全图编码器)**,通道 1 和 2 的编码器由于输入为稀疏子地图,一致性约束意义相对较小。 + +--- + +## 两方案的对比 + +| 维度 | 方案 A(z 闭环) | 方案 B(多路分拆) | +| ---------------- | ---------------------------------- | ----------------------------------------- | +| 核心增益 | 强化现有 z 的控制力 | 增加 z 的语义分层,提升可解释性与可控性 | +| 实现复杂度 | 中等(需处理离散梯度问题) | 较高(需修改编码器输入管线 + 通道标识符) | +| 参数量变化 | 不变(同一编码器复用) | 小幅增加(独立头部 + 通道嵌入) | +| 训练稳定性 | 引入额外损失项,可能需要调权重 | 结构变化较大,需要较长的调试周期 | +| 推理灵活性 | 不变(仍为单路随机采样) | 提升(三路独立采样 / 条件指定) | +| 与现有方案兼容性 | 高(仅增加损失项,不改变模型结构) | 中等(需修改编码器 + dataset pipeline) | + +--- + +## 实施建议 + +### 阶段一:验证方案 A(低风险,快速验证) + +1. 在现有联合训练代码中,对子集 A 的训练步骤增加软分布近似一致性损失; +2. 监控 `L_consist`、`codebook 使用熵`、`生成多样性(pairwise distance)` 三个指标; +3. 若 z 控制力提升明显(生成结果与参考地图结构相似度上升),方案 A 单独使用即可; +4. 若提升有限,再推进方案 B。 + +### 阶段二:实施方案 B(较高收益,中等成本) + +1. 修改 `data/src/auto.ts`,在序列化时输出三个通道的地图矩阵; +2. 修改 `ginka/dataset.py`,为每个样本加载三个通道的输入; +3. 修改编码器(`ginka/vqvae/model.py`),增加通道标识符 token; +4. 修改 MaskGIT 的 cross-attention memory 拼接逻辑; +5. 复用方案 A 的一致性约束,仅作用于通道 3。 + +### 阶段三:多阶段生成(方案 C,后续计划) + +先生成墙与入口的骨架,再生成门与怪物,最后生成资源与道具。此方案依赖方案 B 中通道 1 的编码器作为阶段间信息传递的载体,适合在方案 B 验证稳定后再推进。 + +--- + +## 待细化事项 + +- [x] 方案 A:一致性损失的权重 $\lambda$ 如何随训练进度调度?→ 先使用常量(初始值 0.1),效果不佳再引入调度策略。 +- [x] 方案 A:单步解码还是多步解码后计算一致性损失?→ 训练时 MaskGIT 只进行单步解码,直接在单步结果上计算,无需多步展开。 +- [x] 方案 B:通道 2 的"墙壁"是否需要保留,还是只保留入口 + 怪 + 门?→ 保留墙壁。去掉墙壁后剩余内容趋向于散点,缺乏空间结构指导意义。 +- [x] 方案 B:三路 z 拼接后总长度是否超出 MaskGIT cross-attention 的合理 memory 长度?→ 先直接拼接,如有性能问题再评估截断或压缩策略。 +- [ ] 方案 B:通道 1/2 的编码器在用户未提供条件时,如何优雅地回退到随机采样(dropout 机制设计)?→ 暂不考虑分通道条件指定,等模型结构稳定后再设计此机制。 diff --git a/ginka/train_vq.py b/ginka/train_vq.py index a226a8d..9649412 100644 --- a/ginka/train_vq.py +++ b/ginka/train_vq.py @@ -2,6 +2,7 @@ 联合训练脚本:VQ-VAE + MaskGIT 总损失 = L_CE(MaskGIT 重建损失)+ beta * L_commit + gamma * L_entropy + + lambda * L_consist(z 一致性约束,方案 A) 验证阶段对四种子集(A/B/C/D)分别输出图片, 每条样本额外采样 N_Z_SAMPLES 个随机 z, @@ -63,6 +64,10 @@ MG_DIM_FF = 1024 MG_Z_DROPOUT = 0.1 # 训练时以此概率把 z 替换为随机噪声 MG_STRUCT_DROPOUT= 0.1 # 训练时以此概率将结构标签替换为 null(无条件占位) +# 一致性约束超参(方案 A) +CONSIST_LAMBDA = 0.1 # z 一致性损失权重 +CONSIST_TEMP = 2.0 # 计算软嵌入时对 logits 施加的温度(>1 平滑分布,降低 gap) + # 验证时对每条样本额外采样的 z 数量(0 = 只用真实 z) N_Z_SAMPLES = 3 @@ -354,7 +359,7 @@ def validate( subsets = batch["subset"] # list of str B = raw_map.shape[0] - z_q, _, vq_loss, _, _ = model_vq(raw_map) + z_q, _, _, vq_loss, _, _ = model_vq(raw_map) struct_cond_b = batch["struct_cond"].to(device) # [B, 4] logits = model_mg(masked_map, z_q, struct_cond=struct_cond_b) mask = (masked_map == MASK_TOKEN) @@ -589,6 +594,7 @@ def train(): vq_loss_total = 0.0 commit_total = 0.0 entropy_total = 0.0 + consist_total = 0.0 subset_stats = {'A': 0, 'B': 0, 'C': 0, 'D': 0} for batch in tqdm(dataloader_train, leave=False, @@ -601,8 +607,8 @@ def train(): subset_stats[s] = subset_stats.get(s, 0) + 1 # ---- 前向传播 ---- - # 1. VQ-VAE 编码真实地图 → z_q - z_q, _, vq_loss, commit_loss, entropy_loss = model_vq(raw_map) # z_q: [B, L, d_z] + # 1. VQ-VAE 编码真实地图 → z_q, z_e + z_q, z_e, _, vq_loss, commit_loss, entropy_loss = model_vq(raw_map) # z_q/z_e: [B, L, d_z] # 2. MaskGIT 以掩码地图 + z + 结构标签预测原始 tile struct_cond = batch["struct_cond"].to(device) # [B, 4] @@ -616,8 +622,19 @@ def train(): ) masked_ce = (ce_loss * mask).sum() / (mask.sum() + 1e-6) - # 4. 联合损失 - loss = masked_ce + vq_loss + # 4. z 一致性约束(方案 A):将 MaskGIT 的 logits 经温度平滑后 + # 与 VQ 编码器的 tile embedding 做加权求和,得到软嵌入序列, + # 再送入编码器得到 z_pred_e,约束其与真实 z_e 对齐。 + # 梯度从 z_pred_e 回传到 MaskGIT 的 logits(以及 VQ encoder 的权重); + # z_e 作为 detach 后的监督目标,不产生梯度。 + soft_probs = F.softmax(logits / CONSIST_TEMP, dim=-1) # [B, H*W, V] + tile_emb = model_vq.tile_embedding.weight # [V, d_model] + soft_emb = soft_probs @ tile_emb # [B, H*W, d_model] + z_pred_e = model_vq.encode_soft(soft_emb) # [B, L, d_z] + consist_loss = F.mse_loss(z_pred_e, z_e.detach()) + + # 5. 联合损失 + loss = masked_ce + vq_loss + CONSIST_LAMBDA * consist_loss optimizer.zero_grad() loss.backward() @@ -629,6 +646,7 @@ def train(): vq_loss_total += vq_loss.detach().item() commit_total += commit_loss.detach().item() entropy_total += entropy_loss.detach().item() + consist_total += consist_loss.detach().item() scheduler.step() @@ -640,7 +658,8 @@ def train(): f"CE {ce_total/n:.5f} " f"VQ {vq_loss_total/n:.5f} " f"Commit {commit_total/n:.5f} " - f"Entropy {entropy_total/n:.5f} | " + f"Entropy {entropy_total/n:.5f} " + f"Consist {consist_total/n:.5f} | " f"LR {scheduler.get_last_lr()[0]:.6f} | " f"Subsets {subset_stats}" ) diff --git a/ginka/vqvae/model.py b/ginka/vqvae/model.py index 617a5f9..add46f7 100644 --- a/ginka/vqvae/model.py +++ b/ginka/vqvae/model.py @@ -115,7 +115,32 @@ class GinkaVQVAE(nn.Module): z_e = self.proj(x[:, :self.L]) # [B, L, d_z] return z_e - def forward(self, map: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def encode_soft(self, soft_emb: torch.Tensor) -> torch.Tensor: + """ + 将软嵌入序列编码为量化前的连续向量序列(用于一致性约束)。 + + 与 encode() 的区别:输入是已经过 softmax 加权求和得到的连续嵌入矩阵 + [B, H*W, d_model],而非整数 tile ID。梯度可完整回传到调用方的 logits。 + + Args: + soft_emb: [B, H*W, d_model] softmax 加权 tile 嵌入(已在 d_model 空间) + + Returns: + z_e: [B, L, d_z] 量化前的编码向量 + """ + B = soft_emb.shape[0] + + x = soft_emb + self.pos_embedding # [B, H*W, d_model] + + summary = self.summary_tokens.expand(B, -1, -1) # [B, L, d_model] + x = torch.cat([summary, x], dim=1) # [B, L+H*W, d_model] + + x = self.transformer(x) # [B, L+H*W, d_model] + + z_e = self.proj(x[:, :self.L]) # [B, L, d_z] + return z_e + + def forward(self, map: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ 完整前向传播:编码 → 量化 → 计算损失。 @@ -123,15 +148,18 @@ class GinkaVQVAE(nn.Module): map: [B, H*W] 整数 tile ID(训练时传入完整真实地图) Returns: - z_q: [B, L, d_z] 量化后的 z(含直通梯度),供 MaskGIT 使用 - indices: [B, L] 每个位置对应的码字索引 - vq_loss: scalar VQ 总损失 = beta * commit_loss + gamma * entropy_loss + z_q: [B, L, d_z] 量化后的 z(含直通梯度),供 MaskGIT 使用 + z_e: [B, L, d_z] 量化前的连续编码向量,供一致性约束使用 + indices: [B, L] 每个位置对应的码字索引 + vq_loss: scalar VQ 总损失 = beta * commit_loss + gamma * entropy_loss + commit_loss: scalar + entropy_loss: scalar """ z_e = self.encode(map) z_q, indices, commit_loss, entropy_loss = self.vq(z_e) vq_loss = self.beta * commit_loss + self.gamma * entropy_loss - return z_q, indices, vq_loss, commit_loss, entropy_loss + return z_q, z_e, indices, vq_loss, commit_loss, entropy_loss def sample(self, B: int, device: torch.device) -> torch.Tensor: """ @@ -175,9 +203,10 @@ if __name__ == "__main__": # 前向传播测试 map_input = torch.randint(0, 15, (4, 13 * 13)).to(device) # [B=4, 169] - z_q, indices, vq_loss = model(map_input) + z_q, z_e, indices, vq_loss, commit_loss, entropy_loss = model(map_input) print(f"\nz_q shape: {z_q.shape}") # [4, 2, 64] + print(f"z_e shape: {z_e.shape}") # [4, 2, 64] print(f"indices shape:{indices.shape}") # [4, 2] print(f"vq_loss: {vq_loss.item():.4f}")