feat: 提高 VQ 解码头容量

Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
unanmed 2026-04-28 16:50:15 +08:00
parent 294e214431
commit f22943820c
2 changed files with 52 additions and 23 deletions

View File

@ -50,8 +50,10 @@ VQ_DIM_FF = 512
VQ_BETA = 0.5
VQ_GAMMA = 0.0
# 解码头超参
DH_NHEAD = 8 # Cross-Attention 头数VQ_D_Z=128 可被 8 整除)
# 解码头超参(与编码器对称:同等层数和 FFN 宽度)
DH_NHEAD = 8 # Cross-Attention 头数VQ_D_Z=128 可被 8 整除)
DH_DIM_FF = 512 # FFN 隐层维度(与编码器 VQ_DIM_FF 一致)
DH_LAYERS = 4 # 解码层数(与编码器 VQ_LAYERS 一致)
# ---------------------------------------------------------------------------
# 设备
@ -185,6 +187,8 @@ def train():
d_z=VQ_D_Z,
map_size=MAP_SIZE,
nhead=DH_NHEAD,
dim_ff=DH_DIM_FF,
num_layers=DH_LAYERS,
).to(device)
vq_params = sum(p.numel() for p in model_vq.parameters())

View File

@ -4,18 +4,42 @@ from .quantize import VectorQuantizer
from typing import Tuple
class _DecodeLayer(nn.Module):
"""单个解码层Pre-LN Cross-Attention + Pre-LN FFN。"""
def __init__(self, d_z: int, nhead: int, dim_ff: int):
super().__init__()
self.norm1 = nn.LayerNorm(d_z)
self.cross_attn = nn.MultiheadAttention(
embed_dim=d_z,
num_heads=nhead,
batch_first=True,
dropout=0.0,
)
self.norm2 = nn.LayerNorm(d_z)
self.ffn = nn.Sequential(
nn.Linear(d_z, dim_ff),
nn.GELU(),
nn.Linear(dim_ff, d_z),
)
def forward(self, x: torch.Tensor, z_q: torch.Tensor) -> torch.Tensor:
x = x + self.cross_attn(self.norm1(x), z_q, z_q)[0] # Pre-LN cross-attn
x = x + self.ffn(self.norm2(x)) # Pre-LN FFN
return x
class VQDecodeHead(nn.Module):
"""
VQ-VAE 预训练用轻量解码头Cross-Attention 架构
VQ-VAE 预训练用解码头堆叠 Cross-Attention + FFNPre-LN 风格
z_q [B, L, d_z] 通过 Cross-Attention 还原为地图 logits [B, H*W, num_classes]
z_q [B, L, d_z] 通过多层 Cross-Attention 解码为地图 logits [B, H*W, num_classes]
预训练结束后此模块被丢弃不影响联合训练路径
架构
可学习位置查询 [B, H*W, d_z]
Cross-Attention (query=位置查询, key/value=z_q)
LayerNorm
线性分类头 logits [B, H*W, num_classes]
架构每层
Pre-LN Cross-Attentionquery=可学习位置查询, key/value=z_q
Pre-LN FFN
× num_layers LayerNorm 线性分类头
"""
def __init__(
@ -23,7 +47,9 @@ class VQDecodeHead(nn.Module):
num_classes: int,
d_z: int,
map_size: int,
nhead: int = 4,
nhead: int = 8,
dim_ff: int = 512,
num_layers: int = 4,
):
"""
Args:
@ -31,22 +57,20 @@ class VQDecodeHead(nn.Module):
d_z: z 向量维度须与 GinkaVQVAE d_z 一致
map_size: 地图 token 总数H * W
nhead: Cross-Attention 的注意力头数d_z 须能被 nhead 整除
dim_ff: FFN 隐层维度
num_layers: 解码层数建议与编码器 num_layers 相同
"""
super().__init__()
# 每个格子一个可学习位置查询
self.pos_queries = nn.Parameter(torch.randn(1, map_size, d_z) * 0.02)
# Cross-Attentionquery=位置查询key/value=z_q
self.cross_attn = nn.MultiheadAttention(
embed_dim=d_z,
num_heads=nhead,
batch_first=True,
dropout=0.0,
)
self.norm = nn.LayerNorm(d_z)
# 堆叠多层解码块
self.layers = nn.ModuleList([
_DecodeLayer(d_z, nhead, dim_ff) for _ in range(num_layers)
])
# 最终分类头
self.norm_out = nn.LayerNorm(d_z)
self.classifier = nn.Linear(d_z, num_classes)
def forward(self, z_q: torch.Tensor) -> torch.Tensor:
@ -58,10 +82,11 @@ class VQDecodeHead(nn.Module):
logits: [B, map_size, num_classes]
"""
B = z_q.shape[0]
q = self.pos_queries.expand(B, -1, -1) # [B, map_size, d_z]
out, _ = self.cross_attn(q, z_q, z_q) # [B, map_size, d_z]
out = self.norm(out)
return self.classifier(out) # [B, map_size, num_classes]
x = self.pos_queries.expand(B, -1, -1) # [B, map_size, d_z]
for layer in self.layers:
x = layer(x, z_q)
x = self.norm_out(x)
return self.classifier(x) # [B, map_size, num_classes]
class GinkaVQVAE(nn.Module):