mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-22 19:31:12 +08:00
feat: EMA Codebook
This commit is contained in:
parent
14ee52fb2f
commit
f006522cf9
@ -157,21 +157,22 @@ def build_model(device: torch.device):
|
|||||||
z_seq_len=VQ_L
|
z_seq_len=VQ_L
|
||||||
).to(device)
|
).to(device)
|
||||||
|
|
||||||
# 六个模型参数合并到同一优化器,端到端联合训练
|
|
||||||
all_params = (
|
|
||||||
list(vq1.parameters()) + list(vq2.parameters()) + list(vq3.parameters()) +
|
|
||||||
list(mg1.parameters()) + list(mg2.parameters()) + list(mg3.parameters())
|
|
||||||
)
|
|
||||||
optimizer = optim.AdamW(all_params, lr=LR, weight_decay=1e-4)
|
|
||||||
# 余弦退火:从 LR 线性衰减至 MIN_LR,周期为全部训练轮数
|
|
||||||
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=MIN_LR)
|
|
||||||
|
|
||||||
# 三个独立 VectorQuantizer:各阶段使用自己的码本,避免语义空间相互干扰
|
# 三个独立 VectorQuantizer:各阶段使用自己的码本,避免语义空间相互干扰
|
||||||
quantizer1 = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z).to(device)
|
quantizer1 = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z).to(device)
|
||||||
quantizer2 = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z).to(device)
|
quantizer2 = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z).to(device)
|
||||||
quantizer3 = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z).to(device)
|
quantizer3 = VectorQuantizer(K=VQ_K, d_z=VQ_D_Z).to(device)
|
||||||
quantizers = (quantizer1, quantizer2, quantizer3)
|
quantizers = (quantizer1, quantizer2, quantizer3)
|
||||||
|
|
||||||
|
# 九个模块参数合并到同一优化器,端到端联合训练
|
||||||
|
all_params = (
|
||||||
|
list(vq1.parameters()) + list(vq2.parameters()) + list(vq3.parameters()) +
|
||||||
|
list(mg1.parameters()) + list(mg2.parameters()) + list(mg3.parameters()) +
|
||||||
|
list(quantizer1.parameters()) + list(quantizer2.parameters()) + list(quantizer3.parameters())
|
||||||
|
)
|
||||||
|
optimizer = optim.AdamW(all_params, lr=LR, weight_decay=1e-4)
|
||||||
|
# 余弦退火:从 LR 线性衰减至 MIN_LR,周期为全部训练轮数
|
||||||
|
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=MIN_LR)
|
||||||
|
|
||||||
return vq1, vq2, vq3, mg1, mg2, mg3, quantizers, optimizer, scheduler
|
return vq1, vq2, vq3, mg1, mg2, mg3, quantizers, optimizer, scheduler
|
||||||
|
|
||||||
def cross_entropy_loss(logits, target):
|
def cross_entropy_loss(logits, target):
|
||||||
@ -796,15 +797,9 @@ def train(device: torch.device):
|
|||||||
mg1.load_state_dict(ckpt["mg1"])
|
mg1.load_state_dict(ckpt["mg1"])
|
||||||
mg2.load_state_dict(ckpt["mg2"])
|
mg2.load_state_dict(ckpt["mg2"])
|
||||||
mg3.load_state_dict(ckpt["mg3"])
|
mg3.load_state_dict(ckpt["mg3"])
|
||||||
if "quantizer1" in ckpt:
|
quantizer1.load_state_dict(ckpt["quantizer1"])
|
||||||
quantizer1.load_state_dict(ckpt["quantizer1"])
|
quantizer2.load_state_dict(ckpt["quantizer2"])
|
||||||
quantizer2.load_state_dict(ckpt["quantizer2"])
|
quantizer3.load_state_dict(ckpt["quantizer3"])
|
||||||
quantizer3.load_state_dict(ckpt["quantizer3"])
|
|
||||||
elif "quantizer" in ckpt:
|
|
||||||
quantizer1.load_state_dict(ckpt["quantizer"])
|
|
||||||
quantizer2.load_state_dict(ckpt["quantizer"])
|
|
||||||
quantizer3.load_state_dict(ckpt["quantizer"])
|
|
||||||
tqdm.write("Loaded legacy shared quantizer weights into quantizer1/2/3")
|
|
||||||
# load_optim=False 时可跳过优化器/调度器恢复(适合调整学习率后继续训练)
|
# load_optim=False 时可跳过优化器/调度器恢复(适合调整学习率后继续训练)
|
||||||
if args.load_optim and "optimizer" in ckpt:
|
if args.load_optim and "optimizer" in ckpt:
|
||||||
optimizer.load_state_dict(ckpt["optimizer"])
|
optimizer.load_state_dict(ckpt["optimizer"])
|
||||||
|
|||||||
@ -4,19 +4,29 @@ import torch.nn.functional as F
|
|||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
class VectorQuantizer(nn.Module):
|
class VectorQuantizer(nn.Module):
|
||||||
def __init__(self, K: int, d_z: int):
|
def __init__(
|
||||||
"""
|
self,
|
||||||
Args:
|
K: int,
|
||||||
K: codebook 大小(码字数量)
|
d_z: int,
|
||||||
d_z: 码字嵌入维度
|
decay: float = 0.99,
|
||||||
temp: 软分配 softmax 温度,越小越接近 hard assignment
|
epsilon: float = 1e-5
|
||||||
"""
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.K = K
|
self.K = K
|
||||||
self.d_z = d_z
|
self.d_z = d_z
|
||||||
|
self.decay = decay
|
||||||
|
self.epsilon = epsilon
|
||||||
|
|
||||||
self.codebook = nn.Embedding(K, d_z)
|
self.codebook = nn.Embedding(K, d_z)
|
||||||
nn.init.uniform_(self.codebook.weight, -1.0 / K, 1.0 / K)
|
nn.init.uniform_(self.codebook.weight, -1.0 / K, 1.0 / K)
|
||||||
|
self.codebook.weight.requires_grad_(False)
|
||||||
|
|
||||||
|
# EMA 统计量:码字访问次数与对应编码向量和。
|
||||||
|
self.register_buffer("ema_cluster_size", torch.ones(K))
|
||||||
|
self.register_buffer(
|
||||||
|
"ema_weight",
|
||||||
|
self.codebook.weight.detach().clone()
|
||||||
|
)
|
||||||
|
|
||||||
def codebook_stats(
|
def codebook_stats(
|
||||||
self, indices: torch.Tensor
|
self, indices: torch.Tensor
|
||||||
@ -31,19 +41,32 @@ class VectorQuantizer(nn.Module):
|
|||||||
usage_count = one_hot.sum(dim=0)
|
usage_count = one_hot.sum(dim=0)
|
||||||
return perplexity, usage_rate, usage_count
|
return perplexity, usage_rate, usage_count
|
||||||
|
|
||||||
|
def ema_update(self, z_flat: torch.Tensor, flat_indices: torch.Tensor):
|
||||||
|
one_hot = F.one_hot(flat_indices, num_classes=self.K).type_as(z_flat)
|
||||||
|
cluster_size = one_hot.sum(dim=0)
|
||||||
|
embed_sum = one_hot.transpose(0, 1) @ z_flat
|
||||||
|
|
||||||
|
self.ema_cluster_size.mul_(self.decay).add_(
|
||||||
|
cluster_size,
|
||||||
|
alpha=1.0 - self.decay
|
||||||
|
)
|
||||||
|
self.ema_weight.mul_(self.decay).add_(
|
||||||
|
embed_sum,
|
||||||
|
alpha=1.0 - self.decay
|
||||||
|
)
|
||||||
|
|
||||||
|
total_count = self.ema_cluster_size.sum()
|
||||||
|
normalized_cluster_size = (
|
||||||
|
(self.ema_cluster_size + self.epsilon) /
|
||||||
|
(total_count + self.K * self.epsilon) * total_count
|
||||||
|
)
|
||||||
|
normalized_weight = self.ema_weight / normalized_cluster_size.unsqueeze(1)
|
||||||
|
self.codebook.weight.data.copy_(normalized_weight)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, z_e: torch.Tensor
|
self, z_e: torch.Tensor
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
# z_e: [B, L, d_z]
|
# z_e: [B, L, d_z]
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
z_e: [B, L, d_z] 编码器输出的连续向量序列
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
z_q_st: [B, L, d_z] 量化后向量(直通梯度)
|
|
||||||
indices: [B, L] 每个位置对应的码字索引
|
|
||||||
commit_loss: scalar 承诺损失 ||z_e - sg(z_q)||^2
|
|
||||||
"""
|
|
||||||
B, L, d_z = z_e.shape
|
B, L, d_z = z_e.shape
|
||||||
|
|
||||||
z_flat = z_e.reshape(B * L, d_z) # [B * L, d_z]
|
z_flat = z_e.reshape(B * L, d_z) # [B * L, d_z]
|
||||||
@ -58,10 +81,10 @@ class VectorQuantizer(nn.Module):
|
|||||||
distances = ze_square + ek_square - 2 * mul
|
distances = ze_square + ek_square - 2 * mul
|
||||||
|
|
||||||
# Hard assignment:取最近码字索引
|
# Hard assignment:取最近码字索引
|
||||||
indices = distances.argmin(dim=1) # [B*L]
|
flat_indices = distances.argmin(dim=1) # [B*L]
|
||||||
|
|
||||||
# 量化向量
|
# 量化向量
|
||||||
z_q_flat = self.codebook(indices) # [B*L, d_z]
|
z_q_flat = self.codebook(flat_indices) # [B*L, d_z]
|
||||||
z_q = z_q_flat.reshape(B, L, d_z)
|
z_q = z_q_flat.reshape(B, L, d_z)
|
||||||
|
|
||||||
# 直通估计:前向传 z_q,反向传 z_e 的梯度
|
# 直通估计:前向传 z_q,反向传 z_e 的梯度
|
||||||
@ -70,7 +93,11 @@ class VectorQuantizer(nn.Module):
|
|||||||
# 承诺损失:拉近编码向量与其对应的码字(仅更新编码器)
|
# 承诺损失:拉近编码向量与其对应的码字(仅更新编码器)
|
||||||
commit_loss = F.mse_loss(z_e, z_q.detach())
|
commit_loss = F.mse_loss(z_e, z_q.detach())
|
||||||
|
|
||||||
indices = indices.reshape(B, L)
|
# 训练时使用 EMA 更新码本;验证与推理阶段保持码本固定。
|
||||||
|
if self.training and z_e.requires_grad:
|
||||||
|
self.ema_update(z_flat.detach(), flat_indices.detach())
|
||||||
|
|
||||||
|
indices = flat_indices.reshape(B, L)
|
||||||
perplexity, usage_rate, usage_count = self.codebook_stats(indices)
|
perplexity, usage_rate, usage_count = self.codebook_stats(indices)
|
||||||
return z_q_st, indices, commit_loss, perplexity, usage_count
|
return z_q_st, indices, commit_loss, perplexity, usage_count
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user