mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-22 02:44:51 +08:00
feat: 提高 VQ 解码头容量
Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
parent
294e214431
commit
f22943820c
@ -50,8 +50,10 @@ VQ_DIM_FF = 512
|
|||||||
VQ_BETA = 0.5
|
VQ_BETA = 0.5
|
||||||
VQ_GAMMA = 0.0
|
VQ_GAMMA = 0.0
|
||||||
|
|
||||||
# 解码头超参
|
# 解码头超参(与编码器对称:同等层数和 FFN 宽度)
|
||||||
DH_NHEAD = 8 # Cross-Attention 头数(VQ_D_Z=128 可被 8 整除)
|
DH_NHEAD = 8 # Cross-Attention 头数(VQ_D_Z=128 可被 8 整除)
|
||||||
|
DH_DIM_FF = 512 # FFN 隐层维度(与编码器 VQ_DIM_FF 一致)
|
||||||
|
DH_LAYERS = 4 # 解码层数(与编码器 VQ_LAYERS 一致)
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# 设备
|
# 设备
|
||||||
@ -185,6 +187,8 @@ def train():
|
|||||||
d_z=VQ_D_Z,
|
d_z=VQ_D_Z,
|
||||||
map_size=MAP_SIZE,
|
map_size=MAP_SIZE,
|
||||||
nhead=DH_NHEAD,
|
nhead=DH_NHEAD,
|
||||||
|
dim_ff=DH_DIM_FF,
|
||||||
|
num_layers=DH_LAYERS,
|
||||||
).to(device)
|
).to(device)
|
||||||
|
|
||||||
vq_params = sum(p.numel() for p in model_vq.parameters())
|
vq_params = sum(p.numel() for p in model_vq.parameters())
|
||||||
|
|||||||
@ -4,18 +4,42 @@ from .quantize import VectorQuantizer
|
|||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
|
|
||||||
|
class _DecodeLayer(nn.Module):
|
||||||
|
"""单个解码层:Pre-LN Cross-Attention + Pre-LN FFN。"""
|
||||||
|
|
||||||
|
def __init__(self, d_z: int, nhead: int, dim_ff: int):
|
||||||
|
super().__init__()
|
||||||
|
self.norm1 = nn.LayerNorm(d_z)
|
||||||
|
self.cross_attn = nn.MultiheadAttention(
|
||||||
|
embed_dim=d_z,
|
||||||
|
num_heads=nhead,
|
||||||
|
batch_first=True,
|
||||||
|
dropout=0.0,
|
||||||
|
)
|
||||||
|
self.norm2 = nn.LayerNorm(d_z)
|
||||||
|
self.ffn = nn.Sequential(
|
||||||
|
nn.Linear(d_z, dim_ff),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Linear(dim_ff, d_z),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, z_q: torch.Tensor) -> torch.Tensor:
|
||||||
|
x = x + self.cross_attn(self.norm1(x), z_q, z_q)[0] # Pre-LN cross-attn
|
||||||
|
x = x + self.ffn(self.norm2(x)) # Pre-LN FFN
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
class VQDecodeHead(nn.Module):
|
class VQDecodeHead(nn.Module):
|
||||||
"""
|
"""
|
||||||
VQ-VAE 预训练用轻量解码头(Cross-Attention 架构)。
|
VQ-VAE 预训练用解码头(堆叠 Cross-Attention + FFN,Pre-LN 风格)。
|
||||||
|
|
||||||
将 z_q [B, L, d_z] 通过 Cross-Attention 还原为地图 logits [B, H*W, num_classes]。
|
将 z_q [B, L, d_z] 通过多层 Cross-Attention 解码为地图 logits [B, H*W, num_classes]。
|
||||||
预训练结束后此模块被丢弃,不影响联合训练路径。
|
预训练结束后此模块被丢弃,不影响联合训练路径。
|
||||||
|
|
||||||
架构:
|
架构(每层):
|
||||||
可学习位置查询 [B, H*W, d_z]
|
Pre-LN Cross-Attention(query=可学习位置查询, key/value=z_q)
|
||||||
→ Cross-Attention (query=位置查询, key/value=z_q)
|
Pre-LN FFN
|
||||||
→ LayerNorm
|
× num_layers 层 → LayerNorm → 线性分类头
|
||||||
→ 线性分类头 → logits [B, H*W, num_classes]
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -23,7 +47,9 @@ class VQDecodeHead(nn.Module):
|
|||||||
num_classes: int,
|
num_classes: int,
|
||||||
d_z: int,
|
d_z: int,
|
||||||
map_size: int,
|
map_size: int,
|
||||||
nhead: int = 4,
|
nhead: int = 8,
|
||||||
|
dim_ff: int = 512,
|
||||||
|
num_layers: int = 4,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -31,22 +57,20 @@ class VQDecodeHead(nn.Module):
|
|||||||
d_z: z 向量维度(须与 GinkaVQVAE 的 d_z 一致)
|
d_z: z 向量维度(须与 GinkaVQVAE 的 d_z 一致)
|
||||||
map_size: 地图 token 总数(H * W)
|
map_size: 地图 token 总数(H * W)
|
||||||
nhead: Cross-Attention 的注意力头数(d_z 须能被 nhead 整除)
|
nhead: Cross-Attention 的注意力头数(d_z 须能被 nhead 整除)
|
||||||
|
dim_ff: FFN 隐层维度
|
||||||
|
num_layers: 解码层数(建议与编码器 num_layers 相同)
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# 每个格子一个可学习位置查询
|
# 每个格子一个可学习位置查询
|
||||||
self.pos_queries = nn.Parameter(torch.randn(1, map_size, d_z) * 0.02)
|
self.pos_queries = nn.Parameter(torch.randn(1, map_size, d_z) * 0.02)
|
||||||
|
|
||||||
# Cross-Attention:query=位置查询,key/value=z_q
|
# 堆叠多层解码块
|
||||||
self.cross_attn = nn.MultiheadAttention(
|
self.layers = nn.ModuleList([
|
||||||
embed_dim=d_z,
|
_DecodeLayer(d_z, nhead, dim_ff) for _ in range(num_layers)
|
||||||
num_heads=nhead,
|
])
|
||||||
batch_first=True,
|
|
||||||
dropout=0.0,
|
|
||||||
)
|
|
||||||
self.norm = nn.LayerNorm(d_z)
|
|
||||||
|
|
||||||
# 最终分类头
|
self.norm_out = nn.LayerNorm(d_z)
|
||||||
self.classifier = nn.Linear(d_z, num_classes)
|
self.classifier = nn.Linear(d_z, num_classes)
|
||||||
|
|
||||||
def forward(self, z_q: torch.Tensor) -> torch.Tensor:
|
def forward(self, z_q: torch.Tensor) -> torch.Tensor:
|
||||||
@ -58,10 +82,11 @@ class VQDecodeHead(nn.Module):
|
|||||||
logits: [B, map_size, num_classes]
|
logits: [B, map_size, num_classes]
|
||||||
"""
|
"""
|
||||||
B = z_q.shape[0]
|
B = z_q.shape[0]
|
||||||
q = self.pos_queries.expand(B, -1, -1) # [B, map_size, d_z]
|
x = self.pos_queries.expand(B, -1, -1) # [B, map_size, d_z]
|
||||||
out, _ = self.cross_attn(q, z_q, z_q) # [B, map_size, d_z]
|
for layer in self.layers:
|
||||||
out = self.norm(out)
|
x = layer(x, z_q)
|
||||||
return self.classifier(out) # [B, map_size, num_classes]
|
x = self.norm_out(x)
|
||||||
|
return self.classifier(x) # [B, map_size, num_classes]
|
||||||
|
|
||||||
|
|
||||||
class GinkaVQVAE(nn.Module):
|
class GinkaVQVAE(nn.Module):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user