mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-17 06:51:11 +08:00
feat: 提高 VQ 解码头容量
Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
parent
294e214431
commit
f22943820c
@ -50,8 +50,10 @@ VQ_DIM_FF = 512
|
||||
VQ_BETA = 0.5
|
||||
VQ_GAMMA = 0.0
|
||||
|
||||
# 解码头超参
|
||||
DH_NHEAD = 8 # Cross-Attention 头数(VQ_D_Z=128 可被 8 整除)
|
||||
# 解码头超参(与编码器对称:同等层数和 FFN 宽度)
|
||||
DH_NHEAD = 8 # Cross-Attention 头数(VQ_D_Z=128 可被 8 整除)
|
||||
DH_DIM_FF = 512 # FFN 隐层维度(与编码器 VQ_DIM_FF 一致)
|
||||
DH_LAYERS = 4 # 解码层数(与编码器 VQ_LAYERS 一致)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 设备
|
||||
@ -185,6 +187,8 @@ def train():
|
||||
d_z=VQ_D_Z,
|
||||
map_size=MAP_SIZE,
|
||||
nhead=DH_NHEAD,
|
||||
dim_ff=DH_DIM_FF,
|
||||
num_layers=DH_LAYERS,
|
||||
).to(device)
|
||||
|
||||
vq_params = sum(p.numel() for p in model_vq.parameters())
|
||||
|
||||
@ -4,18 +4,42 @@ from .quantize import VectorQuantizer
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
class _DecodeLayer(nn.Module):
|
||||
"""单个解码层:Pre-LN Cross-Attention + Pre-LN FFN。"""
|
||||
|
||||
def __init__(self, d_z: int, nhead: int, dim_ff: int):
|
||||
super().__init__()
|
||||
self.norm1 = nn.LayerNorm(d_z)
|
||||
self.cross_attn = nn.MultiheadAttention(
|
||||
embed_dim=d_z,
|
||||
num_heads=nhead,
|
||||
batch_first=True,
|
||||
dropout=0.0,
|
||||
)
|
||||
self.norm2 = nn.LayerNorm(d_z)
|
||||
self.ffn = nn.Sequential(
|
||||
nn.Linear(d_z, dim_ff),
|
||||
nn.GELU(),
|
||||
nn.Linear(dim_ff, d_z),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor, z_q: torch.Tensor) -> torch.Tensor:
|
||||
x = x + self.cross_attn(self.norm1(x), z_q, z_q)[0] # Pre-LN cross-attn
|
||||
x = x + self.ffn(self.norm2(x)) # Pre-LN FFN
|
||||
return x
|
||||
|
||||
|
||||
class VQDecodeHead(nn.Module):
|
||||
"""
|
||||
VQ-VAE 预训练用轻量解码头(Cross-Attention 架构)。
|
||||
VQ-VAE 预训练用解码头(堆叠 Cross-Attention + FFN,Pre-LN 风格)。
|
||||
|
||||
将 z_q [B, L, d_z] 通过 Cross-Attention 还原为地图 logits [B, H*W, num_classes]。
|
||||
将 z_q [B, L, d_z] 通过多层 Cross-Attention 解码为地图 logits [B, H*W, num_classes]。
|
||||
预训练结束后此模块被丢弃,不影响联合训练路径。
|
||||
|
||||
架构:
|
||||
可学习位置查询 [B, H*W, d_z]
|
||||
→ Cross-Attention (query=位置查询, key/value=z_q)
|
||||
→ LayerNorm
|
||||
→ 线性分类头 → logits [B, H*W, num_classes]
|
||||
架构(每层):
|
||||
Pre-LN Cross-Attention(query=可学习位置查询, key/value=z_q)
|
||||
Pre-LN FFN
|
||||
× num_layers 层 → LayerNorm → 线性分类头
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -23,7 +47,9 @@ class VQDecodeHead(nn.Module):
|
||||
num_classes: int,
|
||||
d_z: int,
|
||||
map_size: int,
|
||||
nhead: int = 4,
|
||||
nhead: int = 8,
|
||||
dim_ff: int = 512,
|
||||
num_layers: int = 4,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@ -31,22 +57,20 @@ class VQDecodeHead(nn.Module):
|
||||
d_z: z 向量维度(须与 GinkaVQVAE 的 d_z 一致)
|
||||
map_size: 地图 token 总数(H * W)
|
||||
nhead: Cross-Attention 的注意力头数(d_z 须能被 nhead 整除)
|
||||
dim_ff: FFN 隐层维度
|
||||
num_layers: 解码层数(建议与编码器 num_layers 相同)
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
# 每个格子一个可学习位置查询
|
||||
self.pos_queries = nn.Parameter(torch.randn(1, map_size, d_z) * 0.02)
|
||||
|
||||
# Cross-Attention:query=位置查询,key/value=z_q
|
||||
self.cross_attn = nn.MultiheadAttention(
|
||||
embed_dim=d_z,
|
||||
num_heads=nhead,
|
||||
batch_first=True,
|
||||
dropout=0.0,
|
||||
)
|
||||
self.norm = nn.LayerNorm(d_z)
|
||||
# 堆叠多层解码块
|
||||
self.layers = nn.ModuleList([
|
||||
_DecodeLayer(d_z, nhead, dim_ff) for _ in range(num_layers)
|
||||
])
|
||||
|
||||
# 最终分类头
|
||||
self.norm_out = nn.LayerNorm(d_z)
|
||||
self.classifier = nn.Linear(d_z, num_classes)
|
||||
|
||||
def forward(self, z_q: torch.Tensor) -> torch.Tensor:
|
||||
@ -58,10 +82,11 @@ class VQDecodeHead(nn.Module):
|
||||
logits: [B, map_size, num_classes]
|
||||
"""
|
||||
B = z_q.shape[0]
|
||||
q = self.pos_queries.expand(B, -1, -1) # [B, map_size, d_z]
|
||||
out, _ = self.cross_attn(q, z_q, z_q) # [B, map_size, d_z]
|
||||
out = self.norm(out)
|
||||
return self.classifier(out) # [B, map_size, num_classes]
|
||||
x = self.pos_queries.expand(B, -1, -1) # [B, map_size, d_z]
|
||||
for layer in self.layers:
|
||||
x = layer(x, z_q)
|
||||
x = self.norm_out(x)
|
||||
return self.classifier(x) # [B, map_size, num_classes]
|
||||
|
||||
|
||||
class GinkaVQVAE(nn.Module):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user