ginka-generator/ginka/vqvae/quantize.py

69 lines
2.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple
class VectorQuantizer(nn.Module):
    """Hard vector-quantization layer for a VQ-VAE.

    Each continuous encoder vector is snapped to its nearest codebook
    entry under squared-L2 distance; gradients flow back to the encoder
    via the straight-through estimator.
    """

    def __init__(self, K: int, d_z: int):
        """
        Args:
            K: codebook size (number of codewords).
            d_z: embedding dimension of each codeword.
        """
        super().__init__()
        self.K = K
        self.d_z = d_z
        self.codebook = nn.Embedding(K, d_z)
        # Common VQ-VAE init: uniform in a small range scaled by codebook size.
        nn.init.uniform_(self.codebook.weight, -1.0 / K, 1.0 / K)

    def forward(self, z_e: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Quantize a batch of encoder outputs.

        Args:
            z_e: [B, L, d_z] continuous vectors produced by the encoder.
        Returns:
            z_q_st: [B, L, d_z] quantized vectors with straight-through gradient.
            indices: [B, L] index of the assigned codeword at each position.
            commit_loss: scalar commitment loss ||z_e - sg(z_q)||^2.
        """
        B, L, d_z = z_e.shape
        z_flat = z_e.reshape(B * L, d_z)      # [B*L, d_z]
        codebook_w = self.codebook.weight     # [K, d_z]

        # Squared L2 distance via the expansion
        # ||z - e||^2 = ||z||^2 + ||e||^2 - 2 * z . e   -> [B*L, K]
        ze_square = torch.sum(z_flat ** 2, dim=1, keepdim=True)  # [B*L, 1]
        ek_square = torch.sum(codebook_w ** 2, dim=1)            # [K]
        distances = ze_square + ek_square - 2 * (z_flat @ codebook_w.t())

        # Hard assignment: pick the nearest codeword per position.
        indices = distances.argmin(dim=1)     # [B*L]
        z_q = self.codebook(indices).reshape(B, L, d_z)

        # Straight-through estimator: forward pass uses z_q, backward pass
        # copies gradients from z_q_st directly onto z_e.
        z_q_st = z_e + (z_q - z_e).detach()

        # Commitment loss: pulls encoder outputs toward their codewords.
        # The codebook side is detached, so only the encoder is updated here.
        commit_loss = F.mse_loss(z_e, z_q.detach())

        return z_q_st, indices.reshape(B, L), commit_loss

    def sample(self, B: int, L: int, device: torch.device,
               num_segments: int = 3) -> torch.Tensor:
        """Sample random codewords as a synthetic latent sequence.

        Args:
            B: batch size.
            L: length of each segment.
            device: device on which to allocate the sampled indices.
            num_segments: number of independently sampled segments
                concatenated along the sequence dimension (default 3,
                preserving the original hard-coded behavior).
        Returns:
            [B, num_segments * L, d_z] tensor of uniformly sampled codewords.
        """
        # One draw of shape (B, num_segments * L) is distributionally
        # identical to concatenating num_segments independent (B, L) draws.
        indices = torch.randint(0, self.K, (B, num_segments * L), device=device)
        return self.codebook(indices)