github.com/lucidrains/vector-quantize-pytorch @1.27.21 sqlite


Vector Quantization - Pytorch

A vector quantization library originally transcribed from DeepMind's TensorFlow implementation and conveniently packaged. It uses exponential moving averages to update the dictionary (the codebook).

VQ has been successfully used by DeepMind and OpenAI for high-quality generation of images (VQ-VAE-2) and music (Jukebox).

Install

$ pip install vector-quantize-pytorch

Usage

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,     # codebook size
    decay = 0.8,             # the exponential moving average decay, lower means the dictionary will change faster
    commitment_weight = 1.   # the weight on the commitment loss
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x) # (1, 1024, 256), (1, 1024), (1,)
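
To make the role of decay concrete, here is a minimal sketch of one exponential-moving-average codebook update. It is written in plain Python so it runs without PyTorch and is not the library's implementation. The function name ema_update, the initial cluster sizes of 1, and the running sums starting at the codebook values are all assumptions made for illustration.

```python
def ema_update(codebook, cluster_size, embed_sum, batch, decay=0.8):
    """Run one EMA codebook update in place and return the codebook."""
    num_codes, dim = len(codebook), len(codebook[0])
    counts = [0] * num_codes
    sums = [[0.0] * dim for _ in range(num_codes)]

    # Assign each input to its nearest code (squared Euclidean distance).
    for x in batch:
        k = min(range(num_codes),
                key=lambda j: sum((a - b) ** 2 for a, b in zip(x, codebook[j])))
        counts[k] += 1
        sums[k] = [s + a for s, a in zip(sums[k], x)]

    # Blend the batch statistics into the running averages, then renormalize.
    for k in range(num_codes):
        cluster_size[k] = decay * cluster_size[k] + (1 - decay) * counts[k]
        embed_sum[k] = [decay * e + (1 - decay) * s
                        for e, s in zip(embed_sum[k], sums[k])]
        codebook[k] = [e / cluster_size[k] for e in embed_sum[k]]
    return codebook


codebook = [[0.0], [10.0]]
cluster_size = [1.0, 1.0]                       # assumed initial state
embed_sum = [list(code) for code in codebook]   # assumed initial state
ema_update(codebook, cluster_size, embed_sum, batch=[[1.0], [1.0]], decay=0.8)
```

Both inputs land on the first code, so only that code drifts toward 1.0. Rerunning with a smaller decay moves it further in a single step, which is what "lower means the dictionary will change faster" refers to.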

Residual VQ

The SoundStream paper proposes using multiple vector quantizers to recursively quantize the residuals of the waveform. You can use this with the ResidualVQ class and one extra initialization parameter, num_quantizers.

import torch
from vector_quantize_pytorch import ResidualVQ

residual_vq = ResidualVQ(
    dim = 256,
    num_quantizers = 8,      # specify number of quantizers
    codebook_size = 1024,    # codebook size
)

x = torch.randn(1, 1024, 256)

quantized, indices, commit_loss = residual_vq(x)
print(quantized.shape, indices.shape, commit_loss.shape)
# (1, 1024, 256), (1, 1024, 8), (1, 8)

# if you need all the codes across the quantization layers, just pass return_all_codes = True

quantized, indices, commit_loss, all_codes = residual_vq(x, return_all_codes = True)

# (8, 1, 1024, 256)
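
The recursion itself is short. The following dependency-free sketch shows the idea on a single vector, assuming fixed toy codebooks: each layer quantizes what the previous layers left over. The helper names nearest_code and residual_quantize are hypothetical and not part of the library.

```python
def nearest_code(x, codebook):
    """Return (index, code) of the entry closest to x in squared Euclidean distance."""
    idx = min(range(len(codebook)),
              key=lambda k: sum((a - b) ** 2 for a, b in zip(x, codebook[k])))
    return idx, codebook[idx]


def residual_quantize(x, codebooks):
    """Quantize x with a stack of codebooks, each one encoding the remaining residual."""
    residual = list(x)
    quantized = [0.0] * len(x)
    indices, all_codes = [], []
    for codebook in codebooks:
        idx, code = nearest_code(residual, codebook)
        indices.append(idx)
        all_codes.append(code)
        quantized = [q + c for q, c in zip(quantized, code)]   # running sum of codes
        residual = [r - c for r, c in zip(residual, code)]     # what is still unexplained
    return quantized, indices, all_codes


codebooks = [
    [[0.0, 0.0], [1.0, 1.0], [-1.0, -1.0]],   # coarse layer
    [[0.0, 0.0], [0.25, 0.0], [0.0, 0.25]],   # finer layer, sees only the residual
]
quantized, indices, all_codes = residual_quantize([1.2, 0.9], codebooks)
```

The output is the sum of one code per layer, which is why there is one index per quantizer and why all_codes carries a leading quantizer dimension.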

Furthermore, this paper uses Residual-VQ to construct the RQ-VAE, for generating high resolution images with more compressed codes.

They make two modifications. The first is to share the codebook across all quantizers. The second is to stochastically sample the codes rather than always taking the closest match. You can use both of these features with two extra keyword arguments.

import torch
from vector_quantize_pytorch import ResidualVQ

residual_vq = ResidualVQ(
    dim = 256,
    num_quantizers = 8,
    codebook_size = 1024,
    stochastic_sample_codes = True,
    sample_codebook_temp = 0.1,         # temperature for stochastically sampling codes, 0 would be equivalent to non-stochastic
    shared_codebook = True              # whether to share the codebooks for all quantizers or not
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = residual_vq(x)

# (1, 1024, 256), (1, 1024, 8), (1, 8)
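
As a rough illustration of what the temperature does, the sketch below samples a code index with probability proportional to exp(-distance / temperature). This is an assumed formulation chosen for clarity, written in plain Python; the library's actual sampling routine may differ. The name sample_code is hypothetical.

```python
import math
import random


def sample_code(distances, temperature, rng):
    """Pick a code index; temperature 0 falls back to the closest match."""
    if temperature <= 0:
        return min(range(len(distances)), key=distances.__getitem__)
    logits = [-d / temperature for d in distances]
    top = max(logits)                                   # subtract the max for stability
    weights = [math.exp(value - top) for value in logits]
    return rng.choices(range(len(distances)), weights=weights, k=1)[0]


rng = random.Random(0)
distances = [0.1, 2.0, 3.0]          # distances from one input to three codes
greedy = sample_code(distances, temperature=0.0, rng=rng)
```

At a low temperature the closest code is chosen almost every time, and as the temperature grows the choice approaches uniform.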

A recent paper further proposes to do residual VQ on groups of the feature dimension, showing equivalent results to Encodec while using far fewer codebooks. You can use it by importing GroupedResidualVQ.

import torch
from vector_quantize_pytorch import GroupedResidualVQ

residual_vq = GroupedResidualVQ(
    dim = 256,
    num_quantizers = 8,      # specify number of quantizers
    groups = 2,
    codebook_size = 1024,    # codebook size
)

x = torch.randn(1, 1024, 256)

quantized, indices, commit_loss = residual_vq(x)

# (1, 1024, 256), (2, 1, 1024, 8), (2, 1, 8)

Initialization

The SoundStream paper proposes that the codebook should be initialized with the kmeans centroids of the first batch. You can turn this feature on with a single flag, kmeans_init = True, for either the VectorQuantize or the ResidualVQ class.

import torch
from vector_quantize_pytorch import ResidualVQ

residual_vq = ResidualVQ(
    dim = 256,
    codebook_size = 256,
    num_quantizers = 4,
    kmeans_init = True,   # set to True
    kmeans_iters = 10     # number of kmeans iterations to calculate the centroids for the codebook on init
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = residual_vq(x)

# (1, 1024, 256), (1, 1024, 4), (1, 4)
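
For intuition, here is a small plain-Python sketch of kmeans initialization: the vectors of the first batch are clustered and the resulting centroids become the initial codebook. The function name kmeans_init, the random choice of starting centroids, and the handling of empty clusters are assumptions of this sketch, not a description of the library's code.

```python
import random


def kmeans_init(samples, num_codes, iters=10, seed=0):
    """Return num_codes centroids computed from samples with plain Lloyd iterations."""
    rng = random.Random(seed)
    dim = len(samples[0])
    centroids = [list(s) for s in rng.sample(samples, num_codes)]
    for _ in range(iters):
        buckets = [[] for _ in range(num_codes)]
        for s in samples:
            k = min(range(num_codes),
                    key=lambda j: sum((a - b) ** 2 for a, b in zip(s, centroids[j])))
            buckets[k].append(s)
        for k, bucket in enumerate(buckets):
            if bucket:  # an empty cluster keeps its previous centroid
                centroids[k] = [sum(s[d] for s in bucket) / len(bucket)
                                for d in range(dim)]
    return centroids


first_batch = [[0.0], [0.2], [0.4], [10.0], [10.2], [10.4]]
codebook = kmeans_init(first_batch, num_codes=2, iters=10)
```

Starting from centroids that already sit where the data is dense means fewer codes begin far away from every input.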

Gradient Computation

VQ-VAEs are traditionally trained with the straight-through estimator (STE). During the backwards pass, the gradient flows around the VQ layer rather than through it. The rotation trick paper proposes to transform the gradient through the VQ layer so the relative angle and magnitude between the input vector and quantized output are encoded into the gradient. You can enable or disable this feature with rotation_trick=True/False in the VectorQuantize class.

from vector_quantize_pytorch import VectorQuantize

vq_layer = VectorQuantize(
    dim = 256,
    codebook_size = 256,
    rotation_trick = True,   # Set to False to use the STE gradient estimator or True to use the rotation trick.
)
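
The forward value of such a transform can be checked numerically. The sketch below builds a rotation from two Householder reflections that carries the direction of the input e onto the direction of the code q, then rescales by the ratio of norms. This construction is an assumption made for illustration and is not taken from the library's source. Plain Python has no autograd, so only the forward pass is shown. In the rotation trick, the rotation and the scale would be treated as constants during the backward pass.

```python
import math


def _dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def _unit(v):
    norm = math.sqrt(_dot(v, v))
    return [a / norm for a in v], norm


def rotate_onto(e, q, v):
    """Apply to v the rotation that carries the direction of e onto the direction of q."""
    e_hat, _ = _unit(e)
    q_hat, _ = _unit(q)
    # Undefined when q points exactly opposite to e (the sum below is then zero).
    r, _ = _unit([a + b for a, b in zip(e_hat, q_hat)])
    # R v = v - 2 r (r . v) + 2 q_hat (e_hat . v)
    rv, ev = _dot(r, v), _dot(e_hat, v)
    return [a - 2 * rv * b + 2 * ev * c for a, b, c in zip(v, r, q_hat)]


def rotation_trick_forward(e, q):
    """Rotate and rescale e; the result should coincide with q."""
    _, e_norm = _unit(e)
    _, q_norm = _unit(q)
    return [q_norm / e_norm * a for a in rotate_onto(e, q, e)]


e = [1.0, 2.0, 2.0]   # encoder output
q = [0.0, 3.0, 4.0]   # its nearest code
out = rotation_trick_forward(e, q)
```

The forward output equals the quantized vector, just as with the STE. Only the gradient seen by the encoder differs.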

Increasing codebook usage

This repository will contain a few techniques from various papers to combat "dead" codebook entries, which is a common problem when using vector quantizers.

Lower codebook dimension

The Improved VQGAN paper proposes keeping the codebook in a lower dimension. The encoder values are projected down before quantization and projected back up to the original dimension afterwards. You can set this with the codebook_dim hyperparameter.

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 256,
    codebook_dim = 16      # paper proposes setting this to 32 or as low as 8 to increase codebook usage
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)

# (1, 1024, 256), (1, 1024), (1,)

Cosine similarity

The Improved VQGAN paper also proposes to l2-normalize the codes and the encoded vectors, which boils down to using cosine similarity for the distance. They claim that constraining the vectors to a sphere leads to improvements in code usage and downstream reconstruction. You can turn this on by setting use_cosine_sim = True.

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 256,
    use_cosine_sim = True   # set this to True
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)

# (1, 1024, 256), (1, 1024), (1,)
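
The equivalence follows from the identity ||a - b||^2 = 2 - 2 (a . b) for unit vectors: on the sphere, the nearest code in Euclidean distance is the one with the highest cosine similarity. A small plain-Python check, with hypothetical helper names, also shows that the choice can differ from raw Euclidean matching:

```python
import math


def l2norm(v):
    norm = math.sqrt(sum(a * a for a in v))
    return [a / norm for a in v]


def nearest_euclidean(x, codebook):
    dists = [sum((a - b) ** 2 for a, b in zip(x, c)) for c in codebook]
    return min(range(len(codebook)), key=dists.__getitem__)


def nearest_cosine(x, codebook):
    x = l2norm(x)
    sims = [sum(a * b for a, b in zip(x, l2norm(c))) for c in codebook]
    return max(range(len(codebook)), key=sims.__getitem__)


x = [3.0, 0.5]
codebook = [[10.0, 0.0], [0.5, 0.4], [0.0, 1.0]]

raw_choice = nearest_euclidean(x, codebook)       # influenced by vector length
cosine_choice = nearest_cosine(x, codebook)       # depends on direction only
sphere_choice = nearest_euclidean(l2norm(x), [l2norm(c) for c in codebook])
```

Here the long first code is the best match by direction even though a shorter code is closer in raw distance.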

Expiring stale codes

Finally, the SoundStream paper has a scheme where codes whose hit counts fall below a certain threshold are replaced with a randomly selected vector from the current batch. You can set this threshold with the threshold_ema_dead_code keyword.

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    threshold_ema_dead_code = 2  # should actively replace any codes that have an exponential moving average cluster size less than 2
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)

# (1, 1024, 256), (1, 1024), (1,)
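
A minimal sketch of the replacement step, in plain Python and under stated assumptions: the function name expire_dead_codes is hypothetical, and resetting the running cluster size of a replaced code to the threshold is a choice made here so that the code is not expired again immediately.

```python
import random


def expire_dead_codes(codebook, ema_cluster_size, batch, threshold, rng):
    """Replace every code whose EMA cluster size is below threshold; return their indices."""
    replaced = []
    for k, size in enumerate(ema_cluster_size):
        if size < threshold:
            codebook[k] = list(rng.choice(batch))   # random vector from the current batch
            ema_cluster_size[k] = threshold         # assumed reset value
            replaced.append(k)
    return replaced


codebook = [[0.0, 0.0], [5.0, 5.0], [9.0, 9.0]]
ema_cluster_size = [10.0, 0.5, 3.0]
batch = [[1.0, 1.0], [2.0, 2.0]]
replaced = expire_dead_codes(codebook, ema_cluster_size, batch,
                             threshold=2, rng=random.Random(0))
```

Only the second code is below the threshold, so it is the only one moved onto a vector the encoder is currently producing.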

Orthogonal regularization loss

VQ-VAE / VQ-GAN is quickly gaining popularity. A recent paper proposes that when using vector quantization on images, enforcing the codebook to be orthogonal leads to translation equivariance of the discretized codes, leading to large improvements in downstream text to image generation tasks.

You can use this feature by simply setting the orthogonal_reg_weight to be greater than 0, in which case the orthogonal regularization will be added to the auxiliary loss outputted by the module.

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 256,
    accept_image_fmap = True,                   # set this true to be able to pass in an image feature map
    orthogonal_reg_weight = 10,                 # in paper, they recommended a value of 10
    orthogonal_reg_max_codes = 128,             # this would randomly sample from the codebook for the orthogonal regularization loss, for limiting memory usage
    orthogonal_reg_active_codes_only = False    # set this to True if you have a very large codebook, and would only like to enforce the loss on the activated codes per batch
)

img_fmap = torch.randn(1, 256, 32, 32)
quantized, indices, loss = vq(img_fmap) # (1, 256, 32, 32), (1, 32, 32), (1,)

# loss now contains the orthogonal regularization loss with the weight as assigned
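
One common way to write such a penalty is the mean squared deviation of the normalized codebook's Gram matrix from the identity. The plain-Python sketch below uses that formulation as an assumption. It illustrates the idea and makes no claim about the exact expression used in the paper or in this library.

```python
import math


def orthogonal_reg_loss(codebook):
    """Mean squared deviation of the cosine-similarity matrix from the identity."""
    normed = []
    for code in codebook:
        norm = math.sqrt(sum(a * a for a in code))
        normed.append([a / norm for a in code])
    n = len(normed)
    total = 0.0
    for i in range(n):
        for j in range(n):
            cos = sum(a * b for a, b in zip(normed[i], normed[j]))
            target = 1.0 if i == j else 0.0
            total += (cos - target) ** 2
    return total / (n * n)


orthogonal = orthogonal_reg_loss([[1.0, 0.0], [0.0, 2.0]])   # perpendicular codes
collapsed = orthogonal_reg_loss([[1.0, 0.0], [1.0, 0.0]])    # identical codes
```

The penalty vanishes for mutually perpendicular codes and grows as codes align. The cost is quadratic in the number of codes, which is why capping the number of codes that enter the loss helps with memory.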

Multi-headed VQ

There have been a number of papers that propose variants of discrete latent representations with a multi-headed approach (multiple codes per feature). I have decided to offer one variant where the same codebook is used to vector quantize across the input dimension head times.

You can also use a more proven approach (memcodes) from the NWT paper.

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_dim = 32,                  # a number of papers have shown smaller codebook dimension to be acceptable
    heads = 8,                          # number of heads to vector quantize
    separate_codebook_per_head = True,  # whether to have a separate codebook per head. False would mean 1 shared codebook
    codebook_size = 8196,
    accept_image_fmap = True
)

img_fmap = torch.randn(1, 256, 32, 32)
quantized, indices, loss = vq(img_fmap)

# (1, 256, 32, 32), (1, 32, 32, 8), (1,)

Random Projection Quantizer

This paper first proposed using a random projection quantizer for masked speech modeling, where signals are projected with a randomly initialized matrix and then matched against a randomly initialized codebook. One therefore does not need to learn the quantizer. This technique was used by Google's Universal Speech Model to achieve SOTA for speech-to-text modeling.

USM further proposes to use multiple codebooks and to train the masked speech modeling with a multi-softmax objective. You can do this easily by setting num_codebooks to be greater than 1.

import torch
from vector_quantize_pytorch import RandomProjectionQuantizer

quantizer = RandomProjectionQuantizer(
    dim = 512,               # input dimensions
    num_codebooks = 16,      # in USM, they used up to 16 for 5% gain
    codebook_dim = 256,      # codebook dimension
    codebook_size = 1024     # codebook size
)

x = torch.randn(1, 1024, 512)
indices = quantizer(x)

# (1, 1024, 16)
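
Conceptually, nothing in this quantizer is trained. The sketch below shows the idea for a single codebook in plain Python: a fixed random matrix projects the input, and the result is matched against a fixed random codebook. The class name TinyRandomProjectionQuantizer, the Gaussian initialization, and the l2 normalization before matching are assumptions of this sketch rather than details of the library.

```python
import math
import random


def _l2norm(v):
    norm = math.sqrt(sum(a * a for a in v)) or 1.0   # guard against an all-zero vector
    return [a / norm for a in v]


class TinyRandomProjectionQuantizer:
    def __init__(self, dim, codebook_dim, codebook_size, seed=0):
        rng = random.Random(seed)
        # Both the projection and the codebook are random and never updated.
        self.proj = [[rng.gauss(0, 1) for _ in range(codebook_dim)] for _ in range(dim)]
        self.codebook = [_l2norm([rng.gauss(0, 1) for _ in range(codebook_dim)])
                         for _ in range(codebook_size)]

    def __call__(self, x):
        codebook_dim = len(self.proj[0])
        projected = _l2norm([sum(x[i] * self.proj[i][j] for i in range(len(x)))
                             for j in range(codebook_dim)])
        sims = [sum(a * b for a, b in zip(projected, code)) for code in self.codebook]
        return max(range(len(sims)), key=sims.__getitem__)


tiny_quantizer = TinyRandomProjectionQuantizer(dim=8, codebook_dim=4, codebook_size=16)
features = [0.1 * i for i in range(8)]
target_index = tiny_quantizer(features)
```

Because the mapping is deterministic given the seed, the indices can serve as fixed prediction targets for the masked positions.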

This repository should also automatically synchronize the codebooks in a multi-process setting. If somehow it doesn't, please open an issue. You can override whether to synchronize codebooks or not by setting sync_codebook = True | False.

Sim VQ

A new paper proposes a scheme where the codebook is frozen, and the codes are implicitly generated through a linear projection. The authors claim this setup leads to less codebook collapse as well as easier convergence. I have found this to perform even better when paired with the rotation trick from Fifty et al. and when the linear projection is expanded to a small one-layer MLP. You can experiment with it as follows.

Update: hearing mixed results

import torch
from vector_quantize_pytorch import SimVQ

sim_vq = SimVQ(
    dim = 512,
    codebook_size = 1024,
    rotation_trick = True  # use rotation trick from Fifty et al.
)

x = torch.randn(1, 1024, 512)
quantized, indices, commit_loss = sim_vq(x)

assert x.shape == quantized.shape
assert torch.allclose(quantized, sim_vq.indices_to_codes(indices), atol = 1e-6)

For the residual flavor, just import ResidualSimVQ instead.

import torch
from vector_quantize_pytorch import ResidualSimVQ

residual_sim_vq = ResidualSimVQ(
    dim = 512,
    num_quantizers = 4,
    codebook_size = 1024,
    rotation_trick = True  # use rotation trick from Fifty et al.
)

x = torch.randn(1, 1024, 512)
quantized, indices, commit_loss = residual_sim_vq(x)

assert x.shape == quantized.shape
assert torch.allclose(quantized, residual_sim_vq.get_output_from_indices(indices), atol = 1e-6)

Finite Scalar Quantization

                    VQ                                                      FSQ
Quantization        argmin_c || z-c ||                                      round(f(z))
Gradients           Straight Through Estimation (STE)                       STE
Auxiliary Losses    Commitment, codebook, entropy loss, ...                 N/A
Tricks              EMA on codebook, codebook splitting, projections, ...   N/A
Parameters          Codebook                                                N/A

This work out of Google DeepMind aims to vastly simplify the way vector quantization is done for generative modeling, removing the need for commitment losses and EMA updating of the codebook, as well as tackling the issues with codebook collapse or insufficient utilization. They simply round each scalar into discrete levels with straight-through gradients; the codes become uniform points in a hypercube.

Thanks goes out to @sekstini for porting over this implementation in record time!

import torch
from vector_quantize_pytorch import FSQ

quantizer = FSQ(
    levels = [8, 5, 5, 5]
)

x = torch.randn(1, 1024, 4) # 4 since there are 4 levels
xhat, indices = quantizer(x)

# (1, 1024, 4), (1, 1024)

assert torch.all(xhat == quantizer.indices_to_codes(indices))
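
To see why no codebook is needed, here is a simplified plain-Python sketch of the rounding step for a single vector. Each scalar is squashed into a bounded range, rounded to one of its levels, and the per-dimension level numbers are combined into one flat index. The bounding function, the rescaling of codes to [-1, 1], and the function names are assumptions of this sketch. The library's handling of even level counts and of gradients is not reproduced.

```python
import math


def fsq_quantize(z, levels):
    """Quantize one vector; return (codes in [-1, 1], flat index)."""
    digits = []
    for value, num_levels in zip(z, levels):
        bounded = (math.tanh(value) + 1) / 2               # squash into (0, 1)
        digits.append(round(bounded * (num_levels - 1)))   # integer level in [0, num_levels - 1]
    codes = [2 * d / (n - 1) - 1 for d, n in zip(digits, levels)]
    index, basis = 0, 1
    for d, n in zip(digits, levels):                       # mixed-radix encoding of the levels
        index += d * basis
        basis *= n
    return codes, index


def fsq_indices_to_codes(index, levels):
    """Invert the flat index back into the per-dimension codes."""
    codes = []
    for n in levels:
        index, d = divmod(index, n)
        codes.append(2 * d / (n - 1) - 1)
    return codes


levels = [8, 5, 5, 5]                                      # implicit codebook of 8 * 5 * 5 * 5 entries
codes, index = fsq_quantize([0.3, -1.2, 0.0, 2.5], levels)
```

The implicit codebook is the grid of all level combinations, so with these levels there are 1000 possible indices and nothing to learn or keep alive.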

An improvised Residual FSQ, for an attempt to improve audio encoding.

Credit goes to @sekstini for originally incepting the idea [here](https://githu

Core symbols (most depended-on inside this repo)

exists · called by 45 · vector_quantize_pytorch/vector_quantize_pytorch.py
exists · called by 14 · vector_quantize_pytorch/residual_vq.py
default · called by 10 · vector_quantize_pytorch/vector_quantize_pytorch.py
exists · called by 8 · vector_quantize_pytorch/lookup_free_quantization.py
default · called by 8 · vector_quantize_pytorch/residual_vq.py
l2norm · called by 8 · vector_quantize_pytorch/vector_quantize_pytorch.py
default · called by 7 · vector_quantize_pytorch/lookup_free_quantization.py
indices_to_codes · called by 7 · vector_quantize_pytorch/sim_vq.py

Shape

Function 158
Method 114
Class 23
Route 1

Languages

Python 100%

Modules by API surface

vector_quantize_pytorch/vector_quantize_pytorch.py · 59 symbols
vector_quantize_pytorch/residual_vq.py · 30 symbols
tests/test_readme.py · 30 symbols
vector_quantize_pytorch/finite_scalar_quantization.py · 22 symbols
vector_quantize_pytorch/residual_fsq.py · 19 symbols
vector_quantize_pytorch/lookup_free_quantization.py · 19 symbols
vector_quantize_pytorch/residual_lfq.py · 18 symbols
tests/test_latent_quantization.py · 15 symbols
vector_quantize_pytorch/residual_sim_vq.py · 14 symbols
vector_quantize_pytorch/latent_quantization.py · 14 symbols
vector_quantize_pytorch/sim_vq.py · 10 symbols
vector_quantize_pytorch/binary_mapper.py · 8 symbols

Used by 2 indexed graphs (manifest dependencies, hub-wide)

Dependencies (from manifests, versioned)

einops 0.8.0 · 1×
einx 0.3.0 · 1×
torch 2.4 · 1×

For agents

$ claude mcp add vector-quantize-pytorch \
  -- python -m otcore.mcp_server <graph>
