mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-21 10:21:15 +08:00
feat: EMA Codebook
This commit is contained in:
parent
14ee52fb2f
commit
f006522cf9
@ -157,21 +157,22 @@ def build_model(device: torch.device):
|
||||
z_seq_len=VQ_L
|
||||
).to(device)
|
||||
|
||||
# 六个模型参数合并到同一优化器,端到端联合训练
|
||||
all_params = (
|
||||
list(vq1.parameters()) + list(vq2.parameters()) + list(vq3.parameters()) +
|
||||
list(mg1.parameters()) + list(mg2.parameters()) + list(mg3.parameters())
|
||||
)
|
||||
optimizer = optim.AdamW(all_params, lr=LR, weight_decay=1e-4)
|
||||
# 余弦退火:从 LR 线性衰减至 MIN_LR,周期为全部训练轮数
|
||||
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=MIN_LR)
|
||||
|
||||
# 三个独立 VectorQuantizer:各阶段使用自己的码本,避免语义空间相互干扰
|
||||
quantizer1 = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z).to(device)
|
||||
quantizer2 = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z).to(device)
|
||||
quantizer3 = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z).to(device)
|
||||
quantizers = (quantizer1, quantizer2, quantizer3)
|
||||
|
||||
# 九个模块参数合并到同一优化器,端到端联合训练
|
||||
all_params = (
|
||||
list(vq1.parameters()) + list(vq2.parameters()) + list(vq3.parameters()) +
|
||||
list(mg1.parameters()) + list(mg2.parameters()) + list(mg3.parameters()) +
|
||||
list(quantizer1.parameters()) + list(quantizer2.parameters()) + list(quantizer3.parameters())
|
||||
)
|
||||
optimizer = optim.AdamW(all_params, lr=LR, weight_decay=1e-4)
|
||||
# 余弦退火:从 LR 线性衰减至 MIN_LR,周期为全部训练轮数
|
||||
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=MIN_LR)
|
||||
|
||||
return vq1, vq2, vq3, mg1, mg2, mg3, quantizers, optimizer, scheduler
|
||||
|
||||
def cross_entropy_loss(logits, target):
|
||||
@ -796,15 +797,9 @@ def train(device: torch.device):
|
||||
mg1.load_state_dict(ckpt["mg1"])
|
||||
mg2.load_state_dict(ckpt["mg2"])
|
||||
mg3.load_state_dict(ckpt["mg3"])
|
||||
if "quantizer1" in ckpt:
|
||||
quantizer1.load_state_dict(ckpt["quantizer1"])
|
||||
quantizer2.load_state_dict(ckpt["quantizer2"])
|
||||
quantizer3.load_state_dict(ckpt["quantizer3"])
|
||||
elif "quantizer" in ckpt:
|
||||
quantizer1.load_state_dict(ckpt["quantizer"])
|
||||
quantizer2.load_state_dict(ckpt["quantizer"])
|
||||
quantizer3.load_state_dict(ckpt["quantizer"])
|
||||
tqdm.write("Loaded legacy shared quantizer weights into quantizer1/2/3")
|
||||
quantizer1.load_state_dict(ckpt["quantizer1"])
|
||||
quantizer2.load_state_dict(ckpt["quantizer2"])
|
||||
quantizer3.load_state_dict(ckpt["quantizer3"])
|
||||
# load_optim=False 时可跳过优化器/调度器恢复(适合调整学习率后继续训练)
|
||||
if args.load_optim and "optimizer" in ckpt:
|
||||
optimizer.load_state_dict(ckpt["optimizer"])
|
||||
|
||||
@ -4,19 +4,29 @@ import torch.nn.functional as F
|
||||
from typing import Tuple
|
||||
|
||||
class VectorQuantizer(nn.Module):
|
||||
def __init__(self, K: int, d_z: int):
|
||||
"""
|
||||
Args:
|
||||
K: codebook 大小(码字数量)
|
||||
d_z: 码字嵌入维度
|
||||
temp: 软分配 softmax 温度,越小越接近 hard assignment
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
K: int,
|
||||
d_z: int,
|
||||
decay: float = 0.99,
|
||||
epsilon: float = 1e-5
|
||||
):
|
||||
super().__init__()
|
||||
self.K = K
|
||||
self.d_z = d_z
|
||||
self.decay = decay
|
||||
self.epsilon = epsilon
|
||||
|
||||
self.codebook = nn.Embedding(K, d_z)
|
||||
nn.init.uniform_(self.codebook.weight, -1.0 / K, 1.0 / K)
|
||||
self.codebook.weight.requires_grad_(False)
|
||||
|
||||
# EMA 统计量:码字访问次数与对应编码向量和。
|
||||
self.register_buffer("ema_cluster_size", torch.ones(K))
|
||||
self.register_buffer(
|
||||
"ema_weight",
|
||||
self.codebook.weight.detach().clone()
|
||||
)
|
||||
|
||||
def codebook_stats(
|
||||
self, indices: torch.Tensor
|
||||
@ -31,19 +41,32 @@ class VectorQuantizer(nn.Module):
|
||||
usage_count = one_hot.sum(dim=0)
|
||||
return perplexity, usage_rate, usage_count
|
||||
|
||||
def ema_update(self, z_flat: torch.Tensor, flat_indices: torch.Tensor):
|
||||
one_hot = F.one_hot(flat_indices, num_classes=self.K).type_as(z_flat)
|
||||
cluster_size = one_hot.sum(dim=0)
|
||||
embed_sum = one_hot.transpose(0, 1) @ z_flat
|
||||
|
||||
self.ema_cluster_size.mul_(self.decay).add_(
|
||||
cluster_size,
|
||||
alpha=1.0 - self.decay
|
||||
)
|
||||
self.ema_weight.mul_(self.decay).add_(
|
||||
embed_sum,
|
||||
alpha=1.0 - self.decay
|
||||
)
|
||||
|
||||
total_count = self.ema_cluster_size.sum()
|
||||
normalized_cluster_size = (
|
||||
(self.ema_cluster_size + self.epsilon) /
|
||||
(total_count + self.K * self.epsilon) * total_count
|
||||
)
|
||||
normalized_weight = self.ema_weight / normalized_cluster_size.unsqueeze(1)
|
||||
self.codebook.weight.data.copy_(normalized_weight)
|
||||
|
||||
def forward(
|
||||
self, z_e: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
# z_e: [B, L, d_z]
|
||||
"""
|
||||
Args:
|
||||
z_e: [B, L, d_z] 编码器输出的连续向量序列
|
||||
|
||||
Returns:
|
||||
z_q_st: [B, L, d_z] 量化后向量(直通梯度)
|
||||
indices: [B, L] 每个位置对应的码字索引
|
||||
commit_loss: scalar 承诺损失 ||z_e - sg(z_q)||^2
|
||||
"""
|
||||
B, L, d_z = z_e.shape
|
||||
|
||||
z_flat = z_e.reshape(B * L, d_z) # [B * L, d_z]
|
||||
@ -58,10 +81,10 @@ class VectorQuantizer(nn.Module):
|
||||
distances = ze_square + ek_square - 2 * mul
|
||||
|
||||
# Hard assignment:取最近码字索引
|
||||
indices = distances.argmin(dim=1) # [B*L]
|
||||
flat_indices = distances.argmin(dim=1) # [B*L]
|
||||
|
||||
# 量化向量
|
||||
z_q_flat = self.codebook(indices) # [B*L, d_z]
|
||||
z_q_flat = self.codebook(flat_indices) # [B*L, d_z]
|
||||
z_q = z_q_flat.reshape(B, L, d_z)
|
||||
|
||||
# 直通估计:前向传 z_q,反向传 z_e 的梯度
|
||||
@ -70,7 +93,11 @@ class VectorQuantizer(nn.Module):
|
||||
# 承诺损失:拉近编码向量与其对应的码字(仅更新编码器)
|
||||
commit_loss = F.mse_loss(z_e, z_q.detach())
|
||||
|
||||
indices = indices.reshape(B, L)
|
||||
# 训练时使用 EMA 更新码本;验证与推理阶段保持码本固定。
|
||||
if self.training and z_e.requires_grad:
|
||||
self.ema_update(z_flat.detach(), flat_indices.detach())
|
||||
|
||||
indices = flat_indices.reshape(B, L)
|
||||
perplexity, usage_rate, usage_count = self.codebook_stats(indices)
|
||||
return z_q_st, indices, commit_loss, perplexity, usage_count
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user