feat: 提高 VQ 解码头容量

Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
unanmed 2026-04-28 16:50:15 +08:00
parent 294e214431
commit f22943820c
2 changed files with 52 additions and 23 deletions

View File

@ -50,8 +50,10 @@ VQ_DIM_FF = 512
VQ_BETA = 0.5 VQ_BETA = 0.5
VQ_GAMMA = 0.0 VQ_GAMMA = 0.0
# 解码头超参 # 解码头超参(与编码器对称:同等层数和 FFN 宽度)
DH_NHEAD = 8 # Cross-Attention 头数VQ_D_Z=128 可被 8 整除) DH_NHEAD = 8 # Cross-Attention 头数VQ_D_Z=128 可被 8 整除)
DH_DIM_FF = 512 # FFN 隐层维度(与编码器 VQ_DIM_FF 一致)
DH_LAYERS = 4 # 解码层数(与编码器 VQ_LAYERS 一致)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# 设备 # 设备
@ -185,6 +187,8 @@ def train():
d_z=VQ_D_Z, d_z=VQ_D_Z,
map_size=MAP_SIZE, map_size=MAP_SIZE,
nhead=DH_NHEAD, nhead=DH_NHEAD,
dim_ff=DH_DIM_FF,
num_layers=DH_LAYERS,
).to(device) ).to(device)
vq_params = sum(p.numel() for p in model_vq.parameters()) vq_params = sum(p.numel() for p in model_vq.parameters())

View File

@ -4,18 +4,42 @@ from .quantize import VectorQuantizer
from typing import Tuple from typing import Tuple
class _DecodeLayer(nn.Module):
"""单个解码层Pre-LN Cross-Attention + Pre-LN FFN。"""
def __init__(self, d_z: int, nhead: int, dim_ff: int):
super().__init__()
self.norm1 = nn.LayerNorm(d_z)
self.cross_attn = nn.MultiheadAttention(
embed_dim=d_z,
num_heads=nhead,
batch_first=True,
dropout=0.0,
)
self.norm2 = nn.LayerNorm(d_z)
self.ffn = nn.Sequential(
nn.Linear(d_z, dim_ff),
nn.GELU(),
nn.Linear(dim_ff, d_z),
)
def forward(self, x: torch.Tensor, z_q: torch.Tensor) -> torch.Tensor:
x = x + self.cross_attn(self.norm1(x), z_q, z_q)[0] # Pre-LN cross-attn
x = x + self.ffn(self.norm2(x)) # Pre-LN FFN
return x
class VQDecodeHead(nn.Module): class VQDecodeHead(nn.Module):
""" """
VQ-VAE 预训练用轻量解码头Cross-Attention 架构 VQ-VAE 预训练用解码头堆叠 Cross-Attention + FFNPre-LN 风格
z_q [B, L, d_z] 通过 Cross-Attention 还原为地图 logits [B, H*W, num_classes] z_q [B, L, d_z] 通过多层 Cross-Attention 解码为地图 logits [B, H*W, num_classes]
预训练结束后此模块被丢弃不影响联合训练路径 预训练结束后此模块被丢弃不影响联合训练路径
架构 架构每层
可学习位置查询 [B, H*W, d_z] Pre-LN Cross-Attentionquery=可学习位置查询, key/value=z_q
Cross-Attention (query=位置查询, key/value=z_q) Pre-LN FFN
LayerNorm × num_layers LayerNorm 线性分类头
线性分类头 logits [B, H*W, num_classes]
""" """
def __init__( def __init__(
@ -23,7 +47,9 @@ class VQDecodeHead(nn.Module):
num_classes: int, num_classes: int,
d_z: int, d_z: int,
map_size: int, map_size: int,
nhead: int = 4, nhead: int = 8,
dim_ff: int = 512,
num_layers: int = 4,
): ):
""" """
Args: Args:
@ -31,22 +57,20 @@ class VQDecodeHead(nn.Module):
d_z: z 向量维度须与 GinkaVQVAE d_z 一致 d_z: z 向量维度须与 GinkaVQVAE d_z 一致
map_size: 地图 token 总数H * W map_size: 地图 token 总数H * W
nhead: Cross-Attention 的注意力头数d_z 须能被 nhead 整除 nhead: Cross-Attention 的注意力头数d_z 须能被 nhead 整除
dim_ff: FFN 隐层维度
num_layers: 解码层数建议与编码器 num_layers 相同
""" """
super().__init__() super().__init__()
# 每个格子一个可学习位置查询 # 每个格子一个可学习位置查询
self.pos_queries = nn.Parameter(torch.randn(1, map_size, d_z) * 0.02) self.pos_queries = nn.Parameter(torch.randn(1, map_size, d_z) * 0.02)
# Cross-Attentionquery=位置查询key/value=z_q # 堆叠多层解码块
self.cross_attn = nn.MultiheadAttention( self.layers = nn.ModuleList([
embed_dim=d_z, _DecodeLayer(d_z, nhead, dim_ff) for _ in range(num_layers)
num_heads=nhead, ])
batch_first=True,
dropout=0.0,
)
self.norm = nn.LayerNorm(d_z)
# 最终分类头 self.norm_out = nn.LayerNorm(d_z)
self.classifier = nn.Linear(d_z, num_classes) self.classifier = nn.Linear(d_z, num_classes)
def forward(self, z_q: torch.Tensor) -> torch.Tensor: def forward(self, z_q: torch.Tensor) -> torch.Tensor:
@ -58,10 +82,11 @@ class VQDecodeHead(nn.Module):
logits: [B, map_size, num_classes] logits: [B, map_size, num_classes]
""" """
B = z_q.shape[0] B = z_q.shape[0]
q = self.pos_queries.expand(B, -1, -1) # [B, map_size, d_z] x = self.pos_queries.expand(B, -1, -1) # [B, map_size, d_z]
out, _ = self.cross_attn(q, z_q, z_q) # [B, map_size, d_z] for layer in self.layers:
out = self.norm(out) x = layer(x, z_q)
return self.classifier(out) # [B, map_size, num_classes] x = self.norm_out(x)
return self.classifier(x) # [B, map_size, num_classes]
class GinkaVQVAE(nn.Module): class GinkaVQVAE(nn.Module):