ginka-generator/ginka/vqvae/quantize.py
2026-05-20 18:45:46 +08:00

107 lines
3.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple
class VectorQuantizer(nn.Module):
    """Vector-quantization layer whose codebook is maintained by EMA.

    Each encoder vector is replaced by its nearest codebook entry (squared
    L2 distance). The codebook is never trained by gradient descent: its
    weight has ``requires_grad=False`` and is instead recomputed from
    exponential moving averages of the assignment counts and of the sums of
    the assigned encoder vectors.

    Args:
        K: Number of codebook entries.
        d_z: Dimensionality of every codebook entry / encoder vector.
        decay: EMA decay factor applied to the running statistics.
        epsilon: Additive smoothing constant used when normalising the
            running cluster sizes.
    """

    def __init__(
        self,
        K: int,
        d_z: int,
        decay: float = 0.99,
        epsilon: float = 1e-5
    ):
        super().__init__()
        self.K = K
        self.d_z = d_z
        self.decay = decay
        self.epsilon = epsilon

        # The codebook is stored as an embedding table for convenient index
        # lookup; gradients are disabled because EMA updates own the weights.
        self.codebook = nn.Embedding(K, d_z)
        nn.init.uniform_(self.codebook.weight, -1.0 / K, 1.0 / K)
        self.codebook.weight.requires_grad_(False)

        # EMA statistics: per-code assignment count and per-code sum of the
        # encoder vectors assigned to it.
        self.register_buffer("ema_cluster_size", torch.ones(K))
        self.register_buffer(
            "ema_weight",
            self.codebook.weight.detach().clone()
        )

    def codebook_stats(
        self, indices: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Summarise how the codebook is used by a batch of assignments.

        Args:
            indices: Integer tensor of code indices (any shape); it is
                flattened before counting.

        Returns:
            A tuple ``(perplexity, usage_rate, usage_count)`` where
            ``perplexity`` is ``exp`` of the entropy of the empirical code
            distribution, ``usage_rate`` is the fraction of the ``K`` codes
            selected at least once, and ``usage_count`` is a float tensor of
            shape ``[K]`` holding the number of selections per code.
        """
        flat_indices = indices.reshape(-1)
        one_hot = F.one_hot(flat_indices, num_classes=self.K).float()
        avg_probs = one_hot.mean(dim=0)
        # Clamp inside the log so unused codes contribute 0 * log(1e-10) = 0
        # instead of 0 * (-inf) = NaN.
        perplexity = torch.exp(
            -(avg_probs * torch.log(avg_probs.clamp_min(1e-10))).sum()
        )
        usage_rate = (avg_probs > 0).float().mean()
        usage_count = one_hot.sum(dim=0)
        return perplexity, usage_rate, usage_count

    @torch.no_grad()
    def ema_update(
        self, z_flat: torch.Tensor, flat_indices: torch.Tensor
    ) -> None:
        """Update the EMA statistics and rewrite the codebook from them.

        The whole method runs with autograd disabled, so the persistent
        buffers and the codebook never become part of a computation graph,
        even if ``z_flat`` requires grad.

        Args:
            z_flat: Encoder vectors of shape ``[N, d_z]``.
            flat_indices: Integer tensor of shape ``[N]`` holding the code
                index assigned to each row of ``z_flat``.
        """
        one_hot = F.one_hot(flat_indices, num_classes=self.K).type_as(z_flat)
        cluster_size = one_hot.sum(dim=0)              # [K]
        embed_sum = one_hot.transpose(0, 1) @ z_flat   # [K, d_z]

        self.ema_cluster_size.mul_(self.decay).add_(
            cluster_size,
            alpha=1.0 - self.decay
        )
        self.ema_weight.mul_(self.decay).add_(
            embed_sum,
            alpha=1.0 - self.decay
        )

        # Additive smoothing of the running counts: keeps the divisor away
        # from zero for rarely selected codes while preserving the total.
        total_count = self.ema_cluster_size.sum()
        normalized_cluster_size = (
            (self.ema_cluster_size + self.epsilon) /
            (total_count + self.K * self.epsilon) * total_count
        )
        normalized_weight = (
            self.ema_weight / normalized_cluster_size.unsqueeze(1)
        )
        # In-place copy under no_grad instead of going through `.data`.
        self.codebook.weight.copy_(normalized_weight)

    def forward(
        self, z_e: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Quantize encoder outputs to their nearest codebook entries.

        Args:
            z_e: Encoder output of shape ``[B, L, d_z]``.

        Returns:
            A tuple ``(z_q_st, indices, commit_loss, perplexity,
            usage_count)``:

            * ``z_q_st``: quantized vectors ``[B, L, d_z]`` with a
              straight-through gradient to ``z_e``.
            * ``indices``: selected code indices ``[B, L]``.
            * ``commit_loss``: unweighted MSE between ``z_e`` and the
              (detached) quantized vectors.
            * ``perplexity``: codebook perplexity for this batch.
            * ``usage_count``: per-code selection counts ``[K]``.

            The usage rate computed by ``codebook_stats`` is not returned.
        """
        # z_e: [B, L, d_z]
        B, L, d_z = z_e.shape
        z_flat = z_e.reshape(B * L, d_z)  # [B * L, d_z]

        # The nearest-code search is not differentiable (argmin), so it runs
        # without autograd bookkeeping.
        with torch.no_grad():
            codebook_w = self.codebook.weight  # [K, d_z]
            # Squared L2 distance:
            # ||z_e - e_k||^2 = ||z_e||^2 + ||e_k||^2 - 2 * z_e . e_k
            # distances: [B * L, K]
            ze_square = torch.sum(z_flat ** 2, dim=1, keepdim=True)
            ek_square = torch.sum(codebook_w ** 2, dim=1)
            mul = z_flat @ codebook_w.t()
            distances = ze_square + ek_square - 2 * mul
            # Hard assignment: index of the nearest code.
            flat_indices = distances.argmin(dim=1)  # [B * L]

        # Quantized vectors.
        z_q_flat = self.codebook(flat_indices)  # [B * L, d_z]
        z_q = z_q_flat.reshape(B, L, d_z)

        # Straight-through estimator: the forward value is z_q, the backward
        # pass sends the gradient to z_e unchanged.
        z_q_st = z_e + (z_q - z_e).detach()

        # Commitment loss: pulls encoder outputs towards their assigned
        # codes (only the encoder receives a gradient).
        commit_loss = F.mse_loss(z_e, z_q.detach())

        # The codebook is updated by EMA only while training on inputs that
        # carry gradients; it stays fixed during validation and inference.
        if self.training and z_e.requires_grad:
            self.ema_update(z_flat.detach(), flat_indices.detach())

        indices = flat_indices.reshape(B, L)
        perplexity, _, usage_count = self.codebook_stats(indices)
        return z_q_st, indices, commit_loss, perplexity, usage_count

    def sample(self, B: int, L: int, device: torch.device) -> torch.Tensor:
        """Return codebook vectors for uniformly random code indices.

        Args:
            B: Batch size.
            L: Sequence length.
            device: Device on which the random indices are created.

        Returns:
            Tensor of shape ``[B, L, d_z]`` looked up from the codebook.
        """
        indices = torch.randint(0, self.K, (B, L), device=device)
        return self.codebook(indices)