From f006522cf91610b8403f98b760dbf6cd13413638 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Wed, 20 May 2026 18:45:46 +0800 Subject: [PATCH] feat: EMA Codebook --- ginka/train_seperated.py | 31 ++++++++----------- ginka/vqvae/quantize.py | 65 ++++++++++++++++++++++++++++------------ 2 files changed, 59 insertions(+), 37 deletions(-) diff --git a/ginka/train_seperated.py b/ginka/train_seperated.py index 8b0cdca..f17c0ac 100644 --- a/ginka/train_seperated.py +++ b/ginka/train_seperated.py @@ -157,21 +157,22 @@ def build_model(device: torch.device): z_seq_len=VQ_L ).to(device) - # 六个模型参数合并到同一优化器,端到端联合训练 - all_params = ( - list(vq1.parameters()) + list(vq2.parameters()) + list(vq3.parameters()) + - list(mg1.parameters()) + list(mg2.parameters()) + list(mg3.parameters()) - ) - optimizer = optim.AdamW(all_params, lr=LR, weight_decay=1e-4) - # 余弦退火:从 LR 线性衰减至 MIN_LR,周期为全部训练轮数 - scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=MIN_LR) - # 三个独立 VectorQuantizer:各阶段使用自己的码本,避免语义空间相互干扰 quantizer1 = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z).to(device) quantizer2 = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z).to(device) quantizer3 = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z).to(device) quantizers = (quantizer1, quantizer2, quantizer3) + # 九个模块参数合并到同一优化器,端到端联合训练 + all_params = ( + list(vq1.parameters()) + list(vq2.parameters()) + list(vq3.parameters()) + + list(mg1.parameters()) + list(mg2.parameters()) + list(mg3.parameters()) + + list(quantizer1.parameters()) + list(quantizer2.parameters()) + list(quantizer3.parameters()) + ) + optimizer = optim.AdamW(all_params, lr=LR, weight_decay=1e-4) + # 余弦退火:从 LR 线性衰减至 MIN_LR,周期为全部训练轮数 + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=MIN_LR) + return vq1, vq2, vq3, mg1, mg2, mg3, quantizers, optimizer, scheduler def cross_entropy_loss(logits, target): @@ -796,15 +797,9 @@ def train(device: torch.device): mg1.load_state_dict(ckpt["mg1"]) mg2.load_state_dict(ckpt["mg2"]) mg3.load_state_dict(ckpt["mg3"]) - if "quantizer1" in ckpt: - quantizer1.load_state_dict(ckpt["quantizer1"]) - quantizer2.load_state_dict(ckpt["quantizer2"]) - quantizer3.load_state_dict(ckpt["quantizer3"]) - elif "quantizer" in ckpt: - quantizer1.load_state_dict(ckpt["quantizer"]) - quantizer2.load_state_dict(ckpt["quantizer"]) - quantizer3.load_state_dict(ckpt["quantizer"]) - tqdm.write("Loaded legacy shared quantizer weights into quantizer1/2/3") + quantizer1.load_state_dict(ckpt["quantizer1"]) + quantizer2.load_state_dict(ckpt["quantizer2"]) + quantizer3.load_state_dict(ckpt["quantizer3"]) # load_optim=False 时可跳过优化器/调度器恢复(适合调整学习率后继续训练) if args.load_optim and "optimizer" in ckpt: optimizer.load_state_dict(ckpt["optimizer"]) diff --git a/ginka/vqvae/quantize.py b/ginka/vqvae/quantize.py index b9dcaf7..4966b43 100644 --- a/ginka/vqvae/quantize.py +++ b/ginka/vqvae/quantize.py @@ -4,19 +4,29 @@ import torch.nn.functional as F from typing import Tuple class VectorQuantizer(nn.Module): - def __init__(self, K: int, d_z: int): - """ - Args: - K: codebook 大小(码字数量) - d_z: 码字嵌入维度 - temp: 软分配 softmax 温度,越小越接近 hard assignment - """ + def __init__( + self, + K: int, + d_z: int, + decay: float = 0.99, + epsilon: float = 1e-5 + ): super().__init__() self.K = K self.d_z = d_z + self.decay = decay + self.epsilon = epsilon self.codebook = nn.Embedding(K, d_z) nn.init.uniform_(self.codebook.weight, -1.0 / K, 1.0 / K) + self.codebook.weight.requires_grad_(False) + + # EMA 统计量:码字访问次数与对应编码向量和。 + self.register_buffer("ema_cluster_size", torch.ones(K)) + self.register_buffer( + "ema_weight", + self.codebook.weight.detach().clone() + ) def codebook_stats( self, indices: torch.Tensor @@ -31,19 +41,32 @@ class VectorQuantizer(nn.Module): usage_count = one_hot.sum(dim=0) return perplexity, usage_rate, usage_count + def ema_update(self, z_flat: torch.Tensor, flat_indices: torch.Tensor): + one_hot = F.one_hot(flat_indices, num_classes=self.K).type_as(z_flat) + cluster_size = one_hot.sum(dim=0) + embed_sum = one_hot.transpose(0, 1) @ z_flat + + self.ema_cluster_size.mul_(self.decay).add_( + cluster_size, + alpha=1.0 - self.decay + ) + self.ema_weight.mul_(self.decay).add_( + embed_sum, + alpha=1.0 - self.decay + ) + + total_count = self.ema_cluster_size.sum() + normalized_cluster_size = ( + (self.ema_cluster_size + self.epsilon) / + (total_count + self.K * self.epsilon) * total_count + ) + normalized_weight = self.ema_weight / normalized_cluster_size.unsqueeze(1) + self.codebook.weight.data.copy_(normalized_weight) + def forward( self, z_e: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: # z_e: [B, L, d_z] - """ - Args: - z_e: [B, L, d_z] 编码器输出的连续向量序列 - - Returns: - z_q_st: [B, L, d_z] 量化后向量(直通梯度) - indices: [B, L] 每个位置对应的码字索引 - commit_loss: scalar 承诺损失 ||z_e - sg(z_q)||^2 - """ B, L, d_z = z_e.shape z_flat = z_e.reshape(B * L, d_z) # [B * L, d_z] @@ -58,10 +81,10 @@ class VectorQuantizer(nn.Module): distances = ze_square + ek_square - 2 * mul # Hard assignment:取最近码字索引 - indices = distances.argmin(dim=1) # [B*L] + flat_indices = distances.argmin(dim=1) # [B*L] # 量化向量 - z_q_flat = self.codebook(indices) # [B*L, d_z] + z_q_flat = self.codebook(flat_indices) # [B*L, d_z] z_q = z_q_flat.reshape(B, L, d_z) # 直通估计:前向传 z_q,反向传 z_e 的梯度 @@ -70,7 +93,11 @@ class VectorQuantizer(nn.Module): # 承诺损失:拉近编码向量与其对应的码字(仅更新编码器) commit_loss = F.mse_loss(z_e, z_q.detach()) - indices = indices.reshape(B, L) + # 训练时使用 EMA 更新码本;验证与推理阶段保持码本固定。 + if self.training and z_e.requires_grad: + self.ema_update(z_flat.detach(), flat_indices.detach()) + + indices = flat_indices.reshape(B, L) perplexity, usage_rate, usage_count = self.codebook_stats(indices) return z_q_st, indices, commit_loss, perplexity, usage_count