ginka-generator/ginka/vqvae/quantize.py
unanmed 068940cae0 refactor: adopt the VQ + MaskGIT scheme
Co-authored-by: Copilot <copilot@github.com>
2026-04-26 23:45:56 +08:00

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple


class VectorQuantizer(nn.Module):
    """
    Vector quantization (VQ) layer.

    Maps a sequence of continuous encoder vectors to discrete codebook
    indices, and preserves gradient flow via the straight-through
    estimator (STE).

    The uniform-usage regularizer follows a soft-assignment
    entropy-maximization scheme: a softmax over the distances gives soft
    assignment probabilities, the entropy of the average code usage is
    computed, and the negative entropy is minimized to encourage all
    codes to be used equally.
    """

    def __init__(self, K: int, d_z: int, temp: float = 1.0):
        """
        Args:
            K: codebook size (number of code vectors)
            d_z: code embedding dimension
            temp: soft-assignment softmax temperature; smaller values
                approach a hard assignment
        """
        super().__init__()
        self.K = K
        self.d_z = d_z
        self.temp = temp
        self.codebook = nn.Embedding(K, d_z)
        nn.init.uniform_(self.codebook.weight, -1.0 / K, 1.0 / K)
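
    # Worked mini-example of the entropy regularizer (illustrative
    # numbers, not taken from this codebase): with K = 2 and average soft
    # usage p = [0.5, 0.5], entropy_loss = 0.5*log(0.5) + 0.5*log(0.5)
    # ~ -0.693, its minimum; with collapsed usage p ~ [1, 0] it approaches
    # 0, its maximum. Minimizing entropy_loss therefore pushes code usage
    # toward uniform.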
    def forward(self, z_e: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            z_e: [B, L, d_z] sequence of continuous encoder output vectors
        Returns:
            z_q_st: [B, L, d_z] quantized vectors (straight-through gradient)
            indices: [B, L] codebook index for each position
            commit_loss: scalar commitment loss ||z_e - sg(z_q)||^2
            entropy_loss: scalar negative-entropy loss (minimizing it
                maximizes the uniformity of code usage)
        """
        B, L, d_z = z_e.shape
        # Flatten to [B*L, d_z]
        z_flat = z_e.reshape(B * L, d_z)
        codebook_w = self.codebook.weight  # [K, d_z]
        # Squared L2 distances: ||z_e - e_k||^2 = ||z_e||^2 + ||e_k||^2 - 2 * z_e . e_k
        # distances: [B*L, K]
        distances = (
            (z_flat ** 2).sum(dim=1, keepdim=True)  # [B*L, 1]
            + (codebook_w ** 2).sum(dim=1)          # [K]
            - 2.0 * z_flat @ codebook_w.t()         # [B*L, K]
        )
        # Hard assignment: index of the nearest code vector
        indices = distances.argmin(dim=1)  # [B*L]
        # Quantized vectors
        z_q_flat = self.codebook(indices)  # [B*L, d_z]
        z_q = z_q_flat.reshape(B, L, d_z)
        # Straight-through estimator: forward passes z_q, backward passes
        # the gradient straight to z_e
        z_q_st = z_e + (z_q - z_e).detach()
        # Commitment loss: pull encoder outputs toward their assigned
        # codes (updates the encoder only)
        commit_loss = F.mse_loss(z_e, z_q.detach())
        # Entropy-maximization regularizer: compute the average code usage
        # from the soft assignment and minimize the negative entropy.
        # soft_assign: [B*L, K], softmax over negative distances (smaller
        # distance => larger probability)
        soft_assign = F.softmax(-distances / self.temp, dim=1)
        avg_assign = soft_assign.mean(dim=0)  # [K], average code usage
        # entropy_loss = -H(p) = sum(p * log(p)); minimizing it maximizes entropy
        entropy_loss = (avg_assign * torch.log(avg_assign + 1e-10)).sum()
        indices = indices.reshape(B, L)
        return z_q_st, indices, commit_loss, entropy_loss
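

if __name__ == "__main__":
    # Minimal smoke test (a sketch; B=2, L=10, K=16, d_z=32 are arbitrary
    # assumptions, not values from the ginka training config). It checks
    # the output shapes and that the straight-through estimator routes
    # gradients around the non-differentiable argmin back to z_e.
    torch.manual_seed(0)
    vq = VectorQuantizer(K=16, d_z=32, temp=1.0)
    z_e = torch.randn(2, 10, 32, requires_grad=True)
    z_q_st, indices, commit_loss, entropy_loss = vq(z_e)
    assert z_q_st.shape == (2, 10, 32)
    assert indices.shape == (2, 10)
    # Backprop through the quantized output reaches z_e via the STE.
    z_q_st.sum().backward()
    assert z_e.grad is not None
    # A training loop would typically combine these with a reconstruction
    # term, e.g. loss = recon + beta * commit_loss + lam * entropy_loss
    # (beta and lam are hypothetical weights, not taken from this repo).
    print("commit_loss:", commit_loss.item())
    print("entropy_loss:", entropy_loss.item())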