ginka-generator/ginka/vqvae/model.py
unanmed a69403d6bf feat: vq 预训练
Co-authored-by: Copilot <copilot@github.com>
2026-04-28 16:31:53 +08:00

276 lines
9.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
from .quantize import VectorQuantizer
from typing import Tuple
class VQDecodeHead(nn.Module):
    """Lightweight decode head used only for VQ-VAE pre-training (cross-attention).

    Reconstructs map logits [B, H*W, num_classes] from quantized codes
    z_q [B, L, d_z]. Discarded once pre-training finishes; it does not
    participate in the joint-training path.

    Architecture:
        learnable position queries [B, H*W, d_z]
        -> cross-attention (query = position queries, key/value = z_q)
        -> LayerNorm
        -> linear classification head -> logits [B, H*W, num_classes]
    """

    def __init__(
        self,
        num_classes: int,
        d_z: int,
        map_size: int,
        nhead: int = 4,
    ):
        """
        Args:
            num_classes: number of tile classes.
            d_z: z-vector dimensionality (must match GinkaVQVAE's d_z).
            map_size: total number of map tokens (H * W).
            nhead: cross-attention head count (d_z must be divisible by nhead).
        """
        super().__init__()
        # One learnable query per map cell; small init keeps early attention mild.
        self.pos_queries = nn.Parameter(0.02 * torch.randn(1, map_size, d_z))
        # Cross-attention: queries are the position slots, keys/values come from z_q.
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=d_z,
            num_heads=nhead,
            dropout=0.0,
            batch_first=True,
        )
        self.norm = nn.LayerNorm(d_z)
        # Final per-cell classifier.
        self.classifier = nn.Linear(d_z, num_classes)

    def forward(self, z_q: torch.Tensor) -> torch.Tensor:
        """Decode quantized codes into per-cell class logits.

        Args:
            z_q: [B, L, d_z] quantized code sequence.
        Returns:
            logits: [B, map_size, num_classes]
        """
        batch = z_q.shape[0]
        queries = self.pos_queries.expand(batch, -1, -1)   # [B, map_size, d_z]
        attended, _ = self.cross_attn(queries, z_q, z_q)   # [B, map_size, d_z]
        return self.classifier(self.norm(attended))        # [B, map_size, num_classes]
class GinkaVQVAE(nn.Module):
    """VQ-VAE style map encoder.

    Encodes a full map ([B, H*W] integer tile-ID sequence) into L discrete
    codes and emits z [B, L, d_z] as the generation condition for the
    MaskGIT model.

    Architecture:
        tile embedding + positional encoding
        -> prepend L learnable summary tokens to the sequence
        -> Transformer encoder (Pre-LN, self-attention)
        -> keep the first L outputs
        -> linear projection to d_z
        -> VectorQuantizer (straight-through estimator + entropy-maximizing
           regularizer)

    Design constraints:
        - parameter budget < 1M
        - no decoder: z's semantics are constrained indirectly by the
          MaskGIT-side cross-entropy loss
        - z is positioned as a style/diversity control signal, not a
          structural reconstruction target
    """

    def __init__(
        self,
        num_classes: int = 16,
        L: int = 2,
        K: int = 16,
        d_z: int = 64,
        d_model: int = 128,
        nhead: int = 4,
        num_layers: int = 2,
        dim_ff: int = 256,
        map_size: int = 13 * 13,
        beta: float = 0.25,
        gamma: float = 0.1,
        vq_temp: float = 1.0,
    ):
        """
        Args:
            num_classes: number of tile classes (including the MASK token).
            L: code-sequence length, i.e. the sequence dimension of z.
            K: codebook size (total number of codes).
            d_z: code embedding dimension.
            d_model: Transformer hidden width.
            nhead: number of attention heads.
            num_layers: number of Transformer layers.
            dim_ff: feed-forward hidden width.
            map_size: total number of map tokens (H * W).
            beta: commitment-loss weight.
            gamma: entropy-regularizer weight.
            vq_temp: softmax temperature for the soft VQ assignment.
        """
        super().__init__()
        self.L = L
        self.K = K
        self.d_z = d_z
        self.beta = beta
        self.gamma = gamma

        # Tile embedding table.
        self.tile_embedding = nn.Embedding(num_classes, d_model)
        # Positional encoding covers only the map_size cells (summary tokens excluded).
        self.pos_embedding = nn.Parameter(0.02 * torch.randn(1, map_size, d_model))
        # L learnable summary tokens, prepended to the sequence head.
        self.summary_tokens = nn.Parameter(0.02 * torch.randn(1, L, d_model))

        # Pre-LN Transformer encoder (more stable to train).
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_ff,
            dropout=0.0,
            activation='gelu',
            batch_first=True,
            norm_first=True,  # Pre-LN
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers=num_layers)

        # Project Transformer outputs into the codebook dimension d_z.
        self.proj = nn.Sequential(
            nn.Linear(d_model, d_z),
            nn.LayerNorm(d_z),
        )
        # Vector-quantization layer.
        self.vq = VectorQuantizer(K=K, d_z=d_z, temp=vq_temp)

    def encode(self, map: torch.Tensor) -> torch.Tensor:
        """Encode a map into pre-quantization continuous vectors.

        NOTE(review): the parameter name ``map`` shadows the builtin; kept
        for keyword-argument compatibility with existing callers.

        Args:
            map: [B, H*W] integer tile IDs.
        Returns:
            z_e: [B, L, d_z] continuous vectors before quantization.
        """
        batch = map.shape[0]
        tokens = self.tile_embedding(map) + self.pos_embedding   # [B, H*W, d_model]
        prefix = self.summary_tokens.expand(batch, -1, -1)       # [B, L, d_model]
        encoded = self.transformer(torch.cat([prefix, tokens], dim=1))
        return self.proj(encoded[:, : self.L])                   # [B, L, d_z]

    def encode_soft(self, soft_emb: torch.Tensor) -> torch.Tensor:
        """Encode a soft-embedding sequence (used for consistency constraints).

        Unlike encode(), the input is a continuous embedding matrix
        [B, H*W, d_model] obtained as a softmax-weighted sum, already in
        d_model space, rather than integer tile IDs. Gradients flow back
        intact to the caller's logits.

        Args:
            soft_emb: [B, H*W, d_model] softmax-weighted tile embeddings.
        Returns:
            z_e: [B, L, d_z] continuous vectors before quantization.
        """
        batch = soft_emb.shape[0]
        tokens = soft_emb + self.pos_embedding                   # [B, H*W, d_model]
        prefix = self.summary_tokens.expand(batch, -1, -1)       # [B, L, d_model]
        encoded = self.transformer(torch.cat([prefix, tokens], dim=1))
        return self.proj(encoded[:, : self.L])                   # [B, L, d_z]

    def forward(self, map: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Full forward pass: encode -> quantize -> compute losses.

        Args:
            map: [B, H*W] integer tile IDs (the full ground-truth map during training).
        Returns:
            z_q: [B, L, d_z] quantized z (with straight-through gradient) for MaskGIT.
            z_e: [B, L, d_z] pre-quantization continuous vectors, for consistency constraints.
            indices: [B, L] codebook index chosen at each position.
            vq_loss: scalar, total VQ loss = beta * commit_loss + gamma * entropy_loss.
            commit_loss: scalar.
            entropy_loss: scalar.
        """
        z_e = self.encode(map)
        z_q, indices, commit_loss, entropy_loss = self.vq(z_e)
        vq_loss = self.beta * commit_loss + self.gamma * entropy_loss
        return z_q, z_e, indices, vq_loss, commit_loss, entropy_loss

    def sample(self, B: int, device: torch.device) -> torch.Tensor:
        """Inference-time sampling: draw L codes uniformly from the codebook.

        Args:
            B: batch size.
            device: target device.
        Returns:
            z: [B, L, d_z]
        """
        picks = torch.randint(0, self.K, (B, self.L), device=device)
        return self.vq.codebook(picks)  # [B, L, d_z]
if __name__ == "__main__":
    # Smoke test: build the model, report parameter counts, and run one
    # forward pass plus one sampling pass on CPU.
    device = torch.device("cpu")
    model = GinkaVQVAE(
        num_classes=16,
        L=2,
        K=16,
        d_z=64,
        d_model=128,
        nhead=4,
        num_layers=2,
        dim_ff=256,
        map_size=13 * 13,
    ).to(device)

    # Overall parameter count (budget target: < 1M).
    total_params = sum(p.numel() for p in model.parameters())
    print(f"总参数量: {total_params:,} ({total_params / 1e6:.3f}M)")

    # Per-submodule parameter breakdown.
    for name, module in model.named_children():
        n = sum(p.numel() for p in module.parameters())
        print(f" {name}: {n:,}")

    # Forward-pass check on a random map batch.
    map_input = torch.randint(0, 15, (4, 13 * 13), device=device)  # [B=4, 169]
    z_q, z_e, indices, vq_loss, commit_loss, entropy_loss = model(map_input)
    print(f"\nz_q shape: {z_q.shape}")        # [4, 2, 64]
    print(f"z_e shape: {z_e.shape}")          # [4, 2, 64]
    print(f"indices shape:{indices.shape}")   # [4, 2]
    print(f"vq_loss: {vq_loss.item():.4f}")

    # Inference-time sampling check.
    z_sample = model.sample(B=4, device=device)
    print(f"sample shape: {z_sample.shape}")  # [4, 2, 64]