From f22943820cb3a8aec659b0d75748862ad5a166e3 Mon Sep 17 00:00:00 2001
From: unanmed <1319491857@qq.com>
Date: Tue, 28 Apr 2026 16:50:15 +0800
Subject: [PATCH] feat: increase VQ decode head capacity
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Copilot
---
 ginka/train_pretrain.py |  8 +++--
 ginka/vqvae/model.py    | 67 ++++++++++++++++++++++++++++-------------
 2 files changed, 52 insertions(+), 23 deletions(-)

diff --git a/ginka/train_pretrain.py b/ginka/train_pretrain.py
index e8c157a..f27e903 100644
--- a/ginka/train_pretrain.py
+++ b/ginka/train_pretrain.py
@@ -50,8 +50,10 @@
 VQ_DIM_FF = 512
 VQ_BETA = 0.5
 VQ_GAMMA = 0.0
-# Decode head hyperparameters
-DH_NHEAD = 8  # number of cross-attention heads (VQ_D_Z=128 is divisible by 8)
+# Decode head hyperparameters (symmetric with the encoder: same depth and FFN width)
+DH_NHEAD = 8     # number of cross-attention heads (VQ_D_Z=128 is divisible by 8)
+DH_DIM_FF = 512  # FFN hidden dimension (matches encoder VQ_DIM_FF)
+DH_LAYERS = 4    # number of decode layers (matches encoder VQ_LAYERS)
 
 # ---------------------------------------------------------------------------
 # Device
@@ -185,6 +187,8 @@ def train():
         d_z=VQ_D_Z,
         map_size=MAP_SIZE,
         nhead=DH_NHEAD,
+        dim_ff=DH_DIM_FF,
+        num_layers=DH_LAYERS,
     ).to(device)
 
     vq_params = sum(p.numel() for p in model_vq.parameters())
diff --git a/ginka/vqvae/model.py b/ginka/vqvae/model.py
index 89cc4c0..a1ab105 100644
--- a/ginka/vqvae/model.py
+++ b/ginka/vqvae/model.py
@@ -4,18 +4,42 @@
 from .quantize import VectorQuantizer
 from typing import Tuple
 
+class _DecodeLayer(nn.Module):
+    """A single decode layer: Pre-LN cross-attention followed by a Pre-LN FFN."""
+
+    def __init__(self, d_z: int, nhead: int, dim_ff: int):
+        super().__init__()
+        self.norm1 = nn.LayerNorm(d_z)
+        self.cross_attn = nn.MultiheadAttention(
+            embed_dim=d_z,
+            num_heads=nhead,
+            batch_first=True,
+            dropout=0.0,
+        )
+        self.norm2 = nn.LayerNorm(d_z)
+        self.ffn = nn.Sequential(
+            nn.Linear(d_z, dim_ff),
+            nn.GELU(),
+            nn.Linear(dim_ff, d_z),
+        )
+
+    def forward(self, x: torch.Tensor, z_q: torch.Tensor) -> torch.Tensor:
+        x = x + self.cross_attn(self.norm1(x), z_q, z_q)[0]  # Pre-LN cross-attn
+        x = x + self.ffn(self.norm2(x))  # Pre-LN FFN
+        return x
+
+
 class VQDecodeHead(nn.Module):
     """
-    Lightweight decode head for VQ-VAE pretraining (cross-attention architecture).
+    Decode head for VQ-VAE pretraining (stacked cross-attention + FFN, Pre-LN style).
 
-    Reconstructs z_q [B, L, d_z] into map logits [B, H*W, num_classes] via cross-attention.
+    Decodes z_q [B, L, d_z] into map logits [B, H*W, num_classes] through stacked cross-attention layers.
     This module is discarded after pretraining and does not affect the joint-training path.
 
-    Architecture:
-    learnable positional queries [B, H*W, d_z]
-    → cross-attention (query=positional queries, key/value=z_q)
-    → LayerNorm
-    → linear classifier head → logits [B, H*W, num_classes]
+    Architecture (per layer):
+    Pre-LN cross-attention (query=learnable positional queries, key/value=z_q)
+    Pre-LN FFN
+    × num_layers layers → LayerNorm → linear classifier head
     """
 
     def __init__(
@@ -23,7 +47,9 @@ class VQDecodeHead(nn.Module):
         num_classes: int,
         d_z: int,
         map_size: int,
-        nhead: int = 4,
+        nhead: int = 8,
+        dim_ff: int = 512,
+        num_layers: int = 4,
     ):
         """
         Args:
@@ -31,22 +57,20 @@ class VQDecodeHead(nn.Module):
             d_z: dimension of the z vectors (must match the d_z of GinkaVQVAE)
             map_size: total number of map tokens (H * W)
             nhead: number of cross-attention heads (d_z must be divisible by nhead)
+            dim_ff: FFN hidden dimension
+            num_layers: number of decode layers (recommended to match the encoder's num_layers)
         """
         super().__init__()
         # One learnable positional query per map cell
         self.pos_queries = nn.Parameter(torch.randn(1, map_size, d_z) * 0.02)
 
-        # Cross-attention: query=positional queries, key/value=z_q
-        self.cross_attn = nn.MultiheadAttention(
-            embed_dim=d_z,
-            num_heads=nhead,
-            batch_first=True,
-            dropout=0.0,
-        )
-        self.norm = nn.LayerNorm(d_z)
+        # Stack of decode blocks
+        self.layers = nn.ModuleList([
+            _DecodeLayer(d_z, nhead, dim_ff) for _ in range(num_layers)
+        ])
 
-        # Final classifier head
+        # Final LayerNorm + classifier head
+        self.norm_out = nn.LayerNorm(d_z)
         self.classifier = nn.Linear(d_z, num_classes)
 
     def forward(self, z_q: torch.Tensor) -> torch.Tensor:
@@ -58,10 +82,11 @@ class VQDecodeHead(nn.Module):
             logits: [B, map_size, num_classes]
         """
         B = z_q.shape[0]
-        q = self.pos_queries.expand(B, -1, -1)  # [B, map_size, d_z]
-        out, _ = self.cross_attn(q, z_q, z_q)  # [B, map_size, d_z]
-        out = self.norm(out)
-        return self.classifier(out)  # [B, map_size, num_classes]
+        x = self.pos_queries.expand(B, -1, -1)  # [B, map_size, d_z]
+        for layer in self.layers:
+            x = layer(x, z_q)
+        x = self.norm_out(x)
+        return self.classifier(x)  # [B, map_size, num_classes]
 
 
 class GinkaVQVAE(nn.Module):
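
As a quick sanity check of the new head's shape contract, the sketch below feeds a random z_q through VQDecodeHead. This is an illustration, not part of the patch: the batch size, the latent length L=64, num_classes=32, and the 13x13 map_size are assumed values chosen for the example, and the import path simply mirrors the file paths in the diff.

import torch
from ginka.vqvae.model import VQDecodeHead

head = VQDecodeHead(
    num_classes=32,    # assumed tile-class count (illustrative)
    d_z=128,           # VQ_D_Z
    map_size=13 * 13,  # H * W; a 13x13 map is an assumption
    nhead=8,           # DH_NHEAD
    dim_ff=512,        # DH_DIM_FF
    num_layers=4,      # DH_LAYERS
)

z_q = torch.randn(2, 64, 128)  # [B, L, d_z]; L=64 latent tokens assumed
logits = head(z_q)             # [B, map_size, num_classes]
assert logits.shape == (2, 13 * 13, 32)

Because the learnable positional queries, not z_q, fix the output grid, the head always emits map_size positions regardless of the latent length L the encoder produces.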