ginka-generator/ginka/transformer/fsq.py
2026-03-10 23:06:23 +08:00

23 lines
432 B
Python

import torch
import torch.nn as nn
class FSQ(nn.Module):
    """Scalar quantizer that snaps ``tanh``-bounded values to a uniform grid.

    Every element of the input is squashed into [-1, 1] with ``tanh`` and
    then rounded to the nearest of ``levels`` equally spaced grid points
    covering [-1, 1]. The rounding step is bypassed in the backward pass
    (straight-through estimator), so gradients flow as if only ``tanh`` had
    been applied.

    Args:
        levels: Number of quantization levels per element. Must be a scalar
            of at least 2; a smaller value would make the grid spacing
            undefined (division by zero).

    Raises:
        ValueError: If ``levels`` is smaller than 2.
    """

    def __init__(self, levels=7):
        super().__init__()
        if levels < 2:
            # (levels - 1) / 2 would be <= 0, turning every output into NaN
            # or flipping the grid; fail loudly instead.
            raise ValueError(f"levels must be at least 2, got {levels}")
        self.levels = levels
        # Half-width of the integer grid: z * scale lies in [-scale, scale].
        self.scale = (levels - 1) / 2
        # For an even number of levels the grid points sit on half-integers
        # (..., -0.5, 0.5, ...) of the scaled axis; rounding to plain
        # integers would yield levels - 1 (or levels + 1 at saturation)
        # distinct values instead of exactly ``levels``.
        self.offset = 0.5 if levels % 2 == 0 else 0.0

    def forward(self, z):
        """Quantize ``z`` element-wise and return a tensor of the same shape.

        The returned values lie on the grid ``{-1, ..., 1}`` with
        ``self.levels`` points (up to floating-point error introduced by the
        straight-through addition below).
        """
        # Bound the input to [-1, 1].
        z = torch.tanh(z)
        # Quantize: shift onto the integer lattice, round, shift back, and
        # rescale to [-1, 1]. With offset == 0 this is round(z * s) / s.
        z_q = (torch.round(z * self.scale - self.offset) + self.offset) / self.scale
        # Straight-through estimator: forward value is z_q, gradient is
        # that of z because the correction term is detached.
        z_q = z + (z_q - z).detach()
        return z_q