
github.com/lucidrains/x-transformers @2.16.1

492 symbols · 1,714 edges · 27 files · 16 documented (3%) · 1 cross-repo link · updated 2d ago · 2.16.1 · 2026-02-12 · ★ 5,921 · 71 open issues

x-transformers


A concise but fully-featured transformer, complete with a set of promising experimental features from various papers.

Install

$ pip install x-transformers

Usage

Full encoder / decoder

import torch
from x_transformers import XTransformer

model = XTransformer(
    dim = 512,
    enc_num_tokens = 256,
    enc_depth = 6,
    enc_heads = 8,
    enc_max_seq_len = 1024,
    dec_num_tokens = 256,
    dec_depth = 6,
    dec_heads = 8,
    dec_max_seq_len = 1024,
    tie_token_emb = True      # tie embeddings of encoder and decoder
)

src = torch.randint(0, 256, (1, 1024))
src_mask = torch.ones_like(src).bool()
tgt = torch.randint(0, 256, (1, 1024))

loss = model(src, tgt, mask = src_mask) # scalar training loss
loss.backward()

Decoder-only (GPT-like)

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 12,
        heads = 8
    )
).cuda()

x = torch.randint(0, 256, (1, 1024)).cuda()

model(x) # (1, 1024, 20000)

GPT-3 would be approximately the following (though you would not be able to run it anyway)


gpt3 = TransformerWrapper(
    num_tokens = 50000,
    max_seq_len = 2048,
    attn_layers = Decoder(
        dim = 12288,
        depth = 96,
        heads = 96,
        attn_dim_head = 128
    )
).cuda()

Encoder-only (BERT-like)

import torch
from x_transformers import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 512,
        depth = 12,
        heads = 8
    )
).cuda()

x = torch.randint(0, 256, (1, 1024)).cuda()
mask = torch.ones_like(x).bool()

model(x, mask = mask) # (1, 1024, 20000)

State of the art image classification (SimpleViT)

import torch
from x_transformers import ViTransformerWrapper, Encoder

model = ViTransformerWrapper(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8,
    )
)

img = torch.randn(1, 3, 256, 256)
model(img) # (1, 1000)

Image -> caption

import torch
from x_transformers import ViTransformerWrapper, TransformerWrapper, Encoder, Decoder

encoder = ViTransformerWrapper(
    image_size = 256,
    patch_size = 32,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

decoder = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        cross_attend = True
    )
)

img = torch.randn(1, 3, 256, 256)
caption = torch.randint(0, 20000, (1, 1024))

encoded = encoder(img, return_embeddings = True)
decoder(caption, context = encoded) # (1, 1024, 20000)

PaLI, state of the art language-vision model

import torch
from x_transformers import ViTransformerWrapper, XTransformer, Encoder

# PaLI is composed of
# 1. vision transformer (ViTransformerWrapper) +
# 2. encoder-decoder transformer (XTransformer)

vit = ViTransformerWrapper(
    image_size = 256,
    patch_size = 32,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

pali = XTransformer(
    dim = 512,
    enc_num_tokens = 256,
    enc_depth = 6,
    enc_heads = 8,
    enc_max_seq_len = 1024,
    dec_num_tokens = 256,
    dec_depth = 6,
    dec_heads = 8,
    dec_max_seq_len = 1024
)

# training data

img = torch.randn(1, 3, 256, 256)               # images
prompt = torch.randint(0, 256, (1, 1024))       # prompt
prompt_mask = torch.ones(1, 1024).bool()        # prompt text mask
output_text = torch.randint(0, 256, (1, 1024))  # target output text

# train

img_embeds = vit(
    img,
    return_embeddings = True
)

loss = pali(
    prompt,
    output_text,
    mask = prompt_mask,
    src_prepend_embeds = img_embeds             # will prepend image embeddings to encoder text embeddings before attention
)

loss.backward()

# do the above for many steps on a 17B parameter model
# attention is all you need

Dropouts

import torch
from x_transformers import TransformerWrapper, Decoder, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    emb_dropout = 0.1,         # dropout after embedding
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        layer_dropout = 0.1,   # stochastic depth - dropout entire layer
        attn_dropout = 0.1,    # dropout post-attention
        ff_dropout = 0.1       # feedforward dropout
    )
)

x = torch.randint(0, 20000, (1, 1024))
model(x)
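
To make the layer_dropout comment above concrete, here is a minimal sketch of stochastic depth in plain PyTorch. `LayerDropStack` is a hypothetical name for illustration, and the linear layers stand in for real attention and feedforward blocks. It is not how this repository implements the option, only the general idea of skipping whole layers at random during training.

```python
import torch
from torch import nn

class LayerDropStack(nn.Module):
    # hypothetical residual stack that randomly skips whole layers while training
    def __init__(self, layers, layer_dropout = 0.1):
        super().__init__()
        self.layers = nn.ModuleList(layers)
        self.layer_dropout = layer_dropout

    def forward(self, x):
        for layer in self.layers:
            # during training, skip this layer entirely with probability `layer_dropout`
            if self.training and torch.rand(1).item() < self.layer_dropout:
                continue
            x = x + layer(x)  # residual connection around the layer
        return x

stack = LayerDropStack([nn.Linear(512, 512) for _ in range(6)], layer_dropout = 0.1)
x = torch.randn(1, 1024, 512)

stack.eval()     # no layers are skipped at inference time
out = stack(x)   # (1, 1024, 512)
```

Because a skipped layer leaves only the residual path, the output shape never changes, which is why the option can be toggled without touching the rest of the model.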

Features

Flash Attention

What originally started off as a short paper from Markus Rabe culminated in a practical fused attention CUDA kernel, named Flash Attention by Tri Dao.

The technique processes the attention matrix in tiles, only keeping track of the running softmax and exponentiated weighted sums. By recomputing on the backwards pass in a tiled fashion, one is able to keep the memory linear with respect to sequence length. This allows a lot of recent models to be able to reach for longer context lengths without worrying about the memory bottleneck.
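
The running softmax described above can be sketched in a few lines of plain PyTorch. This is only an illustration of the forward-pass arithmetic for a single head, under the assumption of unbatched 2D inputs. `tiled_attention` is a hypothetical helper, not the fused kernel, and it gives none of the kernel's speed benefits.

```python
import torch

def tiled_attention(q, k, v, tile_size = 128):
    # q: (n, d), k and v: (m, d) - single head, no batch, for clarity
    scale = q.shape[-1] ** -0.5
    n = q.shape[0]

    running_max = q.new_full((n, 1), float('-inf'))  # row-wise max seen so far
    running_sum = q.new_zeros(n, 1)                  # softmax denominator so far
    acc = q.new_zeros(n, v.shape[-1])                # exponentiated weighted sum of values

    for start in range(0, k.shape[0], tile_size):
        k_tile = k[start:start + tile_size]
        v_tile = v[start:start + tile_size]

        scores = (q @ k_tile.t()) * scale            # (n, tile) - only one tile in memory
        new_max = torch.maximum(running_max, scores.amax(dim = -1, keepdim = True))

        # rescale what was accumulated under the old max, then add this tile
        correction = torch.exp(running_max - new_max)
        weights = torch.exp(scores - new_max)

        running_sum = running_sum * correction + weights.sum(dim = -1, keepdim = True)
        acc = acc * correction + weights @ v_tile
        running_max = new_max

    return acc / running_sum

torch.manual_seed(0)
q, k, v = torch.randn(64, 32), torch.randn(300, 32), torch.randn(300, 32)

naive = ((q @ k.t()) * 32 ** -0.5).softmax(dim = -1) @ v  # materializes the full (64, 300) matrix
tiled = tiled_attention(q, k, v, tile_size = 128)         # matches naive up to floating point error
```

The full attention matrix never exists at once in the tiled version. Only one (n, tile_size) block of scores is held at a time, alongside three small running buffers.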

Other engineering decisions made by Tri Dao led to its enormous success, namely minimizing HBM accesses so that both the forwards and backwards outperform naive attention. In other words, flash attention is not only more memory efficient, but faster as well, making it a necessity for training transformers.

MetaAI has recently added the ability to use Tri Dao's CUDA kernel through the scaled_dot_product_attention function in PyTorch 2.0. (They also have a mem_efficient attention, which is identical in design to flash attention, except that the tiles are traversed differently.)
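
As a rough sketch of what calling that function directly looks like, the snippet below compares it against naive causal attention. It assumes PyTorch 2.0 or later. Which backend (flash, memory efficient, or plain math) actually runs depends on the hardware, dtype and inputs, so this only demonstrates that the results agree.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(1, 8, 128, 64)  # (batch, heads, seq, dim_head)
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)

# fused path, never hands the attention matrix back to the caller
fused = F.scaled_dot_product_attention(q, k, v, is_causal = True)

# naive path, materializes the full (seq, seq) attention matrix
scores = (q @ k.transpose(-1, -2)) * q.shape[-1] ** -0.5
causal_mask = torch.ones(128, 128, dtype = torch.bool).triu(1)
scores = scores.masked_fill(causal_mask, float('-inf'))
naive = scores.softmax(dim = -1) @ v
```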

Llama was trained using Flash Attention. The only reason to avoid it is if you require operating on the attention matrix (dynamic positional bias, talking heads, residual attention).

You can use it in this repository by setting attn_flash to True and enjoy the immediate memory savings and increase in speed.

ex.

import torch
from x_transformers import TransformerWrapper, Decoder, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_flash = True # just set this to True if you have pytorch 2.0 installed
    )
)

Augmenting Self-attention with Persistent Memory

https://arxiv.org/abs/1907.01470

Proposes adding learned memory key / values prior to attention. They were able to remove feedforwards altogether and attain similar performance to the original transformers. I have found that keeping the feedforwards and adding the memory key / values leads to even better performance.

from x_transformers import Decoder, Encoder

enc = Encoder(
    dim = 512,
    depth = 6,
    heads = 8,
    attn_num_mem_kv = 16 # 16 memory key / values
)
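
For intuition, memory key / values can be sketched as learned parameters concatenated to the projected keys and values before the softmax. `MemoryKVAttention` below is a hypothetical, stripped-down module written for illustration (no masking, no dropout), not the repository's attention class.

```python
import torch
from torch import nn

class MemoryKVAttention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, num_mem_kv = 16):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

        # learned keys / values, shared across the batch and independent of the input
        self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
        self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))

    def forward(self, x):
        b, n, _ = x.shape
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = (t.reshape(b, n, self.heads, -1).transpose(1, 2) for t in (q, k, v))

        # prepend the memory to the keys / values along the sequence dimension
        mem_k = self.mem_k.unsqueeze(0).expand(b, -1, -1, -1)
        mem_v = self.mem_v.unsqueeze(0).expand(b, -1, -1, -1)
        k = torch.cat((mem_k, k), dim = 2)  # (b, heads, num_mem_kv + n, dim_head)
        v = torch.cat((mem_v, v), dim = 2)

        attn = ((q @ k.transpose(-1, -2)) * self.scale).softmax(dim = -1)
        out = (attn @ v).transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out)

layer = MemoryKVAttention(dim = 64, heads = 4, dim_head = 16, num_mem_kv = 16)
x = torch.randn(2, 10, 64)
out = layer(x)  # (2, 10, 64) - output length is unchanged, only keys / values grow
```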

Memory Transformers

https://arxiv.org/abs/2006.11527

Proposes adding learned tokens, akin to CLS tokens, named memory tokens, that are passed through the attention layers alongside the input tokens. This setting is compatible with both encoder and decoder training.

import torch
from x_transformers import TransformerWrapper, Decoder, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    num_memory_tokens = 20, # 20 memory tokens
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)
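
Conceptually, memory tokens are learned vectors prepended to the sequence before the attention layers and stripped off afterwards. The sketch below assumes this simple prepend-and-slice scheme. `MemoryTokenWrapper` is a hypothetical name, and a stock PyTorch encoder layer stands in for the attention layers.

```python
import torch
from torch import nn

class MemoryTokenWrapper(nn.Module):
    def __init__(self, layers, dim, num_memory_tokens = 20):
        super().__init__()
        self.layers = layers
        self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))

    def forward(self, x):
        b = x.shape[0]
        num_mem = self.memory_tokens.shape[0]

        # same learned tokens for every sequence in the batch
        mem = self.memory_tokens.unsqueeze(0).expand(b, -1, -1)
        x = torch.cat((mem, x), dim = 1)  # (b, num_mem + n, dim)

        x = self.layers(x)                # memory and input tokens attend to each other
        return x[:, num_mem:]             # drop the memory positions again

layers = nn.TransformerEncoderLayer(d_model = 64, nhead = 4, batch_first = True)
model = MemoryTokenWrapper(layers, dim = 64, num_memory_tokens = 20)

x = torch.randn(2, 10, 64)
out = model(x)  # (2, 10, 64)
```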

Update: MetaAI researchers have found that adding memory tokens (which they call register tokens) alleviates outliers, now suspected to be a pathology of attention networks that are unable to attend to nothing.

Update 2: a hybrid architecture out of Nvidia named Hymba used memory tokens successfully in the autoregressive case, termed meta tokens in their paper.

Update 3: further corroborated by a paper trying to extend memory in attention networks, termed persistent memory

Transformers Without Tears

https://arxiv.org/abs/1910.05895

They experiment with alternatives to layer normalization and find one that is both effective and simpler. Researchers have told me that it leads to faster convergence.

import torch
from x_transformers import TransformerWrapper, Decoder, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        use_scalenorm = True # set to True to use for all layers
    )
)
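
The normalization proposed in the paper, ScaleNorm, divides each vector by its l2 norm and multiplies by a single learned scalar. A minimal sketch follows, assuming the scalar is initialized to sqrt(dim) as in the paper. The repository's own parameterization may differ.

```python
import torch
from torch import nn

class ScaleNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.tensor(dim ** 0.5))  # one learned scalar, not a vector

    def forward(self, x):
        norm = x.norm(dim = -1, keepdim = True).clamp(min = self.eps)
        return self.g * x / norm

norm = ScaleNorm(512)
x = torch.randn(1, 1024, 512)
out = norm(x)  # every vector along the last dimension now has length g
```

Compared with layer normalization, there is no mean centering, no bias, and only one learned parameter per layer.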

You can also use the l2 normalized embeddings proposed as part of fixnorm. I have found it leads to improved convergence when paired with small initialization (proposed by BlinkDL). The small initialization will be taken care of as long as l2norm_embed is set to True.

import torch
from x_transformers import TransformerWrapper, Decoder, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    l2norm_embed = True,    # set this to True for l2 normalized embedding + small init
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)
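
What this amounts to can be sketched as an embedding table with a small initialization whose outputs are projected onto the unit sphere. `L2NormEmbedding` is a hypothetical name, and the standard deviation of 1e-5 is an assumption chosen for illustration, not a value taken from this repository.

```python
import torch
from torch import nn
import torch.nn.functional as F

class L2NormEmbedding(nn.Module):
    def __init__(self, num_tokens, dim, init_std = 1e-5):
        super().__init__()
        self.emb = nn.Embedding(num_tokens, dim)
        nn.init.normal_(self.emb.weight, std = init_std)  # small init, value is an assumption

    def forward(self, ids):
        # l2 normalize each token embedding to unit length
        return F.normalize(self.emb(ids), dim = -1)

emb = L2NormEmbedding(num_tokens = 20000, dim = 64)
ids = torch.randint(0, 20000, (1, 1024))
out = emb(ids)  # (1, 1024, 64), each vector has unit norm
```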

Along the same lines as l2 normalized embeddings, Huggingface's 176B parameter BLOOM also places a layernorm right after the embeddings, just before the tokens enter the attention layers. Yandex's 100B parameter YaLM corroborated that this stabilizes training.

It is recommended that you set either l2norm_embed or post_emb_norm to True, but not both, as they probably serve the same purpose.

import torch
from x_transformers import TransformerWrapper, Decoder, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    post_emb_norm = True,    # set this to True to layernorm summed token + pos embeddings
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)
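
As a sketch of where that layernorm sits, the hypothetical module below sums token and absolute positional embeddings and normalizes the result before anything reaches the attention layers. It assumes learned absolute positions, and the names are made up for illustration.

```python
import torch
from torch import nn

class EmbeddingWithPostNorm(nn.Module):
    def __init__(self, num_tokens, max_seq_len, dim):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(max_seq_len, dim)
        self.norm = nn.LayerNorm(dim)

    def forward(self, ids):
        positions = torch.arange(ids.shape[1], device = ids.device)
        x = self.token_emb(ids) + self.pos_emb(positions)  # summed token + pos embeddings
        return self.norm(x)                                # normalized before the attention layers

emb = EmbeddingWithPostNorm(num_tokens = 20000, max_seq_len = 1024, dim = 64)
ids = torch.randint(0, 20000, (1, 1024))
out = emb(ids)  # (1, 1024, 64)
```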

Root Mean Square Layer Normalization

https://arxiv.org/abs/1910.07467

The authors propose to replace layer normalization with a simpler alternative, without mean centering and the learned bias. An investigative paper found this to be the best performing normalization variant. It was also used in Deepmind's latest large language models, Retro and Gopher.

import torch
from x_transformers import TransformerWrapper, Decoder, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        use_rmsnorm = True # set to true to use for all layers
    )
)
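
For reference, RMS normalization as described above can be sketched as follows: divide by the root mean square of the features, then apply a learned per-feature gain, with no mean centering and no bias. This is a generic version written for illustration, not the repository's module.

```python
import torch
from torch import nn

class RMSNorm(nn.Module):
    def __init__(self, dim, eps = 1e-8):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(dim))  # learned multiplicative gain only

    def forward(self, x):
        rms = x.pow(2).mean(dim = -1, keepdim = True).sqrt()
        return x / rms.clamp(min = self.eps) * self.gamma

norm = RMSNorm(512)
x = torch.randn(1, 1024, 512) + 3.  # a non-zero mean is left in place, unlike layernorm
out = norm(x)
```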

July 2023: A linear attention paper showed experimentally that removing the learned multiplicative gamma led to no performance degradation. This simplifies the RMS normalization to a satisfying l2norm(x) * sqrt(dim).

import torch
from x_transformers import TransformerWrapper, Decoder, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        use_simple_rmsnorm = True # set to true to use for all layers
    )
)
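
The simplification holds because the l2 norm of a vector equals sqrt(dim) times its root mean square, so l2norm(x) * sqrt(dim) is the same as x / rms(x). A quick numerical check, with `simple_rmsnorm` as a hypothetical helper name:

```python
import torch
import torch.nn.functional as F

def simple_rmsnorm(x):
    dim = x.shape[-1]
    return F.normalize(x, dim = -1) * dim ** 0.5  # l2norm(x) * sqrt(dim)

torch.manual_seed(0)
x = torch.randn(1, 1024, 512)

rms = x.pow(2).mean(dim = -1, keepdim = True).sqrt()
reference = x / rms           # RMS normalization without the learned gamma
simple = simple_rmsnorm(x)    # same values, up to floating point error
```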

GLU Variants Improve Transformer

https://arxiv.org/abs/2002.05202

A paper by Noam Shazeer that explores gating in the feedforward, finding that simple gating with GELU leads to significant improvements. This variant also showed up in the latest mT5 architecture. You should always turn this on (I may eventually turn it on by default).

import torch
from x_transformers import TransformerWrapper, Decoder, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        ff_glu = True # set to true to use for all feedforwards
    )
)
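
The gating itself is small: the input projection is doubled in width, split into a value half and a gate half, and the value is multiplied by the activated gate. A minimal sketch follows, with `GLUFeedForward` and its arguments being hypothetical names rather than the repository's API.

```python
import torch
from torch import nn
import torch.nn.functional as F

class GLUFeedForward(nn.Module):
    def __init__(self, dim, mult = 4, activation = F.gelu):
        super().__init__()
        inner_dim = dim * mult
        self.activation = activation
        self.proj_in = nn.Linear(dim, inner_dim * 2)  # twice as wide: value and gate
        self.proj_out = nn.Linear(inner_dim, dim)

    def forward(self, x):
        value, gate = self.proj_in(x).chunk(2, dim = -1)
        return self.proj_out(value * self.activation(gate))  # GELU-gated linear unit

ff = GLUFeedForward(dim = 512)
x = torch.randn(1, 1024, 512)
out = ff(x)  # (1, 1024, 512)
```

Passing `activation = F.silu` to the same sketch gives the Swish-gated variant.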

The PaLM language model also chose to use the Swish GLU variant. You can turn this on by setting two flags.

import torch
from x_transformers import TransformerWrapper, Decoder, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        ff_swish = True, # set this to True
        ff_glu = True    # set to true to use for all feedforwards
    )
)
