feat: EMA Codebook

This commit is contained in:
unanmed 2026-05-20 18:45:46 +08:00
parent 14ee52fb2f
commit f006522cf9
2 changed files with 59 additions and 37 deletions

View File

@ -157,21 +157,22 @@ def build_model(device: torch.device):
z_seq_len=VQ_L
).to(device)
# 六个模型参数合并到同一优化器,端到端联合训练
all_params = (
list(vq1.parameters()) + list(vq2.parameters()) + list(vq3.parameters()) +
list(mg1.parameters()) + list(mg2.parameters()) + list(mg3.parameters())
)
optimizer = optim.AdamW(all_params, lr=LR, weight_decay=1e-4)
# 余弦退火:从 LR 线性衰减至 MIN_LR周期为全部训练轮数
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=MIN_LR)
# 三个独立 VectorQuantizer各阶段使用自己的码本避免语义空间相互干扰
quantizer1 = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z).to(device)
quantizer2 = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z).to(device)
quantizer3 = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z).to(device)
quantizers = (quantizer1, quantizer2, quantizer3)
# 九个模块参数合并到同一优化器,端到端联合训练
all_params = (
list(vq1.parameters()) + list(vq2.parameters()) + list(vq3.parameters()) +
list(mg1.parameters()) + list(mg2.parameters()) + list(mg3.parameters()) +
list(quantizer1.parameters()) + list(quantizer2.parameters()) + list(quantizer3.parameters())
)
optimizer = optim.AdamW(all_params, lr=LR, weight_decay=1e-4)
# 余弦退火:从 LR 线性衰减至 MIN_LR周期为全部训练轮数
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=MIN_LR)
return vq1, vq2, vq3, mg1, mg2, mg3, quantizers, optimizer, scheduler
def cross_entropy_loss(logits, target):
@ -796,15 +797,9 @@ def train(device: torch.device):
mg1.load_state_dict(ckpt["mg1"])
mg2.load_state_dict(ckpt["mg2"])
mg3.load_state_dict(ckpt["mg3"])
if "quantizer1" in ckpt:
quantizer1.load_state_dict(ckpt["quantizer1"])
quantizer2.load_state_dict(ckpt["quantizer2"])
quantizer3.load_state_dict(ckpt["quantizer3"])
elif "quantizer" in ckpt:
quantizer1.load_state_dict(ckpt["quantizer"])
quantizer2.load_state_dict(ckpt["quantizer"])
quantizer3.load_state_dict(ckpt["quantizer"])
tqdm.write("Loaded legacy shared quantizer weights into quantizer1/2/3")
quantizer1.load_state_dict(ckpt["quantizer1"])
quantizer2.load_state_dict(ckpt["quantizer2"])
quantizer3.load_state_dict(ckpt["quantizer3"])
# load_optim=False 时可跳过优化器/调度器恢复(适合调整学习率后继续训练)
if args.load_optim and "optimizer" in ckpt:
optimizer.load_state_dict(ckpt["optimizer"])

View File

@ -4,19 +4,29 @@ import torch.nn.functional as F
from typing import Tuple
class VectorQuantizer(nn.Module):
def __init__(self, K: int, d_z: int):
"""
Args:
K: codebook 大小码字数量
d_z: 码字嵌入维度
temp: 软分配 softmax 温度越小越接近 hard assignment
"""
def __init__(
self,
K: int,
d_z: int,
decay: float = 0.99,
epsilon: float = 1e-5
):
super().__init__()
self.K = K
self.d_z = d_z
self.decay = decay
self.epsilon = epsilon
self.codebook = nn.Embedding(K, d_z)
nn.init.uniform_(self.codebook.weight, -1.0 / K, 1.0 / K)
self.codebook.weight.requires_grad_(False)
# EMA 统计量:码字访问次数与对应编码向量和。
self.register_buffer("ema_cluster_size", torch.ones(K))
self.register_buffer(
"ema_weight",
self.codebook.weight.detach().clone()
)
def codebook_stats(
self, indices: torch.Tensor
@ -31,19 +41,32 @@ class VectorQuantizer(nn.Module):
usage_count = one_hot.sum(dim=0)
return perplexity, usage_rate, usage_count
def ema_update(self, z_flat: torch.Tensor, flat_indices: torch.Tensor):
one_hot = F.one_hot(flat_indices, num_classes=self.K).type_as(z_flat)
cluster_size = one_hot.sum(dim=0)
embed_sum = one_hot.transpose(0, 1) @ z_flat
self.ema_cluster_size.mul_(self.decay).add_(
cluster_size,
alpha=1.0 - self.decay
)
self.ema_weight.mul_(self.decay).add_(
embed_sum,
alpha=1.0 - self.decay
)
total_count = self.ema_cluster_size.sum()
normalized_cluster_size = (
(self.ema_cluster_size + self.epsilon) /
(total_count + self.K * self.epsilon) * total_count
)
normalized_weight = self.ema_weight / normalized_cluster_size.unsqueeze(1)
self.codebook.weight.data.copy_(normalized_weight)
def forward(
self, z_e: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# z_e: [B, L, d_z]
"""
Args:
z_e: [B, L, d_z] 编码器输出的连续向量序列
Returns:
z_q_st: [B, L, d_z] 量化后向量直通梯度
indices: [B, L] 每个位置对应的码字索引
commit_loss: scalar 承诺损失 ||z_e - sg(z_q)||^2
"""
B, L, d_z = z_e.shape
z_flat = z_e.reshape(B * L, d_z) # [B * L, d_z]
@ -58,10 +81,10 @@ class VectorQuantizer(nn.Module):
distances = ze_square + ek_square - 2 * mul
# Hard assignment取最近码字索引
indices = distances.argmin(dim=1) # [B*L]
flat_indices = distances.argmin(dim=1) # [B*L]
# 量化向量
z_q_flat = self.codebook(indices) # [B*L, d_z]
z_q_flat = self.codebook(flat_indices) # [B*L, d_z]
z_q = z_q_flat.reshape(B, L, d_z)
# 直通估计:前向传 z_q反向传 z_e 的梯度
@ -70,7 +93,11 @@ class VectorQuantizer(nn.Module):
# 承诺损失:拉近编码向量与其对应的码字(仅更新编码器)
commit_loss = F.mse_loss(z_e, z_q.detach())
indices = indices.reshape(B, L)
# 训练时使用 EMA 更新码本;验证与推理阶段保持码本固定。
if self.training and z_e.requires_grad:
self.ema_update(z_flat.detach(), flat_indices.detach())
indices = flat_indices.reshape(B, L)
perplexity, usage_rate, usage_count = self.codebook_stats(indices)
return z_q_st, indices, commit_loss, perplexity, usage_count