
github.com/lucidrains/DALLE-pytorch @1.6.6 sqlite

227 symbols · 661 edges · 19 files · 40 documented (18%)
README

DALL-E in Pytorch

Train DALL-E w/ DeepSpeed Join us on Discord

Released DALLE Models

Web-Hostable DALLE Checkpoints

Yannic Kilcher's video

Implementation / replication of DALL-E (paper), OpenAI's Text to Image Transformer, in Pytorch. It will also contain CLIP for ranking the generations.


Quick Start

Deep Daze or Big Sleep are great alternatives!

For generating video and audio, please see NÜWA

Appreciation

This library would not have been possible without the contributions of janEbert, Clay, robvanvolt, Romaine, and Alexander! 🙏

Status

  • Hannu has managed to train a small 6 layer DALL-E on a dataset of just 2000 landscape images! (2048 visual tokens)

  • Kobiso, a research engineer from Naver, has trained on the CUB200 dataset here, using full and deepspeed sparse attention

  • (3/15/21) afiaka87 has managed one epoch using a reversible DALL-E and the dVaE here

  • TheodoreGalanos has trained on 150k layouts with the following results

  • Rom1504 has trained on 50k fashion images with captions with a really small DALL-E (2 layers) for just 24 hours with the following results

  • afiaka87 trained for 6 epochs on the same dataset as before thanks to the efficient 16k VQGAN with the following

Thanks to the amazing "mega b#6696" you can generate from this checkpoint in colab - Run inference on the Afiaka checkpoint in Colab

  • (5/2/21) First 1.3B DALL-E from 🇷🇺 has been trained and released to the public! 🎉

  • (4/8/22) Moving onwards to DALLE-2!

Install

$ pip install dalle-pytorch

Usage

Train VAE

import torch
from dalle_pytorch import DiscreteVAE

vae = DiscreteVAE(
    image_size = 256,
    num_layers = 3,           # number of downsamples - ex. 256 / (2 ** 3) = (32 x 32 feature map)
    num_tokens = 8192,        # number of visual tokens. in the paper, they used 8192, but could be smaller for downsized projects
    codebook_dim = 512,       # codebook dimension
    hidden_dim = 64,          # hidden dimension
    num_resnet_blocks = 1,    # number of resnet blocks
    temperature = 0.9,        # gumbel softmax temperature, the lower this is, the harder the discretization
    straight_through = False, # straight-through for gumbel softmax. unclear if it is better one way or the other
)

images = torch.randn(4, 3, 256, 256)

loss = vae(images, return_loss = True)
loss.backward()

# train with a lot of data to learn a good codebook
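
The comment on num_layers above implies a simple relationship between image size, number of downsamples, and the length of the image token sequence. A minimal sketch of that arithmetic, assuming each layer halves the spatial resolution (the helper name visual_token_grid is hypothetical and not part of dalle_pytorch):

```python
def visual_token_grid(image_size, num_layers):
    """Return (feature-map side length, number of image tokens).

    Assumes every downsampling layer halves height and width.
    """
    side = image_size // (2 ** num_layers)
    return side, side * side


# 256 px image, 3 downsamples -> 32 x 32 feature map -> 1024 image tokens
side, seq_len = visual_token_grid(image_size=256, num_layers=3)
print(side, seq_len)
```

The second value is the image sequence length that DALLE infers from the VAE in the examples that follow.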

Train DALL-E with pretrained VAE from above

import torch
from dalle_pytorch import DiscreteVAE, DALLE

vae = DiscreteVAE(
    image_size = 256,
    num_layers = 3,
    num_tokens = 8192,
    codebook_dim = 1024,
    hidden_dim = 64,
    num_resnet_blocks = 1,
    temperature = 0.9
)

dalle = DALLE(
    dim = 1024,
    vae = vae,                  # automatically infer (1) image sequence length and (2) number of image tokens
    num_text_tokens = 10000,    # vocab size for text
    text_seq_len = 256,         # text sequence length
    depth = 12,                 # should aim to be 64
    heads = 16,                 # attention heads
    dim_head = 64,              # attention head dimension
    attn_dropout = 0.1,         # attention dropout
    ff_dropout = 0.1            # feedforward dropout
)

text = torch.randint(0, 10000, (4, 256))
images = torch.randn(4, 3, 256, 256)

loss = dalle(text, images, return_loss = True)
loss.backward()

# do the above for a long time with a lot of data ... then

images = dalle.generate_images(text)
images.shape # (4, 3, 256, 256)

To prime with a starting crop of an image, simply pass two more arguments

img_prime = torch.randn(4, 3, 256, 256)

images = dalle.generate_images(
    text,
    img = img_prime,
    num_init_img_tokens = (14 * 32)  # you can set the size of the initial crop, defaults to a little less than ~1/2 of the tokens, as done in the paper
)

images.shape # (4, 3, 256, 256)
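
To see why 14 * 32 counts as "a little less than ~1/2 of the tokens", it helps to work out the fraction. The sketch below assumes a 32 x 32 feature map (1024 image tokens) laid out row by row, so that 14 * 32 corresponds to the first 14 rows; the helper name primed_fraction is hypothetical.

```python
def primed_fraction(rows, row_width, image_seq_len):
    """Fraction of the image token sequence taken from the priming image."""
    return (rows * row_width) / image_seq_len


image_seq_len = 32 * 32             # assumed 32 x 32 feature map
num_init_img_tokens = 14 * 32       # value used in the example above
fraction = primed_fraction(14, 32, image_seq_len)
print(num_init_img_tokens, fraction)
```

Under these assumptions 448 of 1024 tokens come from the priming image, a fraction of 0.4375.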

You may also want to generate text using DALL-E. For that call this function:

text_tokens, texts = dalle.generate_texts(tokenizer, text)

OpenAI's Pretrained VAE

You can also skip the training of the VAE altogether, using the pretrained model released by OpenAI! The wrapper class should take care of downloading and caching the model for you auto-magically.

import torch
from dalle_pytorch import OpenAIDiscreteVAE, DALLE

vae = OpenAIDiscreteVAE()       # loads pretrained OpenAI VAE

dalle = DALLE(
    dim = 1024,
    vae = vae,                  # automatically infer (1) image sequence length and (2) number of image tokens
    num_text_tokens = 10000,    # vocab size for text
    text_seq_len = 256,         # text sequence length
    depth = 1,                  # should aim to be 64
    heads = 16,                 # attention heads
    dim_head = 64,              # attention head dimension
    attn_dropout = 0.1,         # attention dropout
    ff_dropout = 0.1            # feedforward dropout
)

text = torch.randint(0, 10000, (4, 256))
images = torch.randn(4, 3, 256, 256)

loss = dalle(text, images, return_loss = True)
loss.backward()

Taming Transformer's Pretrained VQGAN VAE

You can also use the pretrained VAE offered by the authors of Taming Transformers! Currently only the VAE with a codebook size of 1024 is offered, with the hope that it may train a little faster than OpenAI's, which has a size of 8192.

In contrast to OpenAI's VAE, it also has an extra layer of downsampling, so the image sequence length is 256 instead of 1024 (this will lead to a 16x reduction in training costs, when you do the math). Whether it will generalize as well as the original DALL-E is up to the citizen scientists out there to discover.
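
The factor of 16 can be reproduced with a rough back-of-the-envelope calculation. This is only a sketch: it assumes full attention whose cost grows with the square of the sequence length, and it counts image tokens only, ignoring text tokens and the feedforward layers, so real-world savings will differ. The helper name attention_cost_ratio is hypothetical.

```python
def attention_cost_ratio(seq_len_a, seq_len_b):
    """Ratio of quadratic attention cost between two sequence lengths."""
    return (seq_len_a / seq_len_b) ** 2


openai_seq_len = 1024  # image sequence length with OpenAI's VAE
vqgan_seq_len = 256    # image sequence length with the VQGAN VAE

ratio = attention_cost_ratio(openai_seq_len, vqgan_seq_len)
print(ratio)  # 4x shorter sequence -> 16x cheaper quadratic term
```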

Update - it works!

from dalle_pytorch import VQGanVAE

vae = VQGanVAE()

# the rest is the same as the above example

The default VQGAN is the one with a codebook size of 1024, trained on ImageNet. To use a different one, pass the .ckpt file and the .yaml file through vqgan_model_path and vqgan_config_path. These options can be given either to the train-dalle script or as arguments to the VQGanVAE class. Other pretrained VQGANs can be found in the taming transformers readme. If you want to train a custom one, you can follow this guide.

Adjust text conditioning strength

A technique has recently surfaced for guiding diffusion models without a classifier. The gist is to randomly drop out the text condition during training and, at inference time, to derive the rough direction from the unconditional to the conditional distribution.

Katherine Crowson outlined in a tweet how this could work for autoregressive attention models. I have decided to include her idea in this repository for further exploration. One only has to account for two extra keyword arguments on training (null_cond_prob) and generation (cond_scale).

import torch
from dalle_pytorch import DiscreteVAE, DALLE

vae = DiscreteVAE(
    image_size = 256,
    num_layers = 3,
    num_tokens = 8192,
    codebook_dim = 1024,
    hidden_dim = 64,
    num_resnet_blocks = 1,
    temperature = 0.9
)

dalle = DALLE(
    dim = 1024,
    vae = vae,
    num_text_tokens = 10000,
    text_seq_len = 256,
    depth = 12,
    heads = 16,
    dim_head = 64,
    attn_dropout = 0.1,
    ff_dropout = 0.1
)

text = torch.randint(0, 10000, (4, 256))
images = torch.randn(4, 3, 256, 256)

loss = dalle(
    text,
    images,
    return_loss = True,
    null_cond_prob = 0.2  # firstly, set this to the probability of dropping out the condition, 20% is recommended as a default
)

loss.backward()

# do the above for a long time with a lot of data ... then

images = dalle.generate_images(
    text,
    cond_scale = 3. # secondly, set this to a value greater than 1 to increase the conditioning beyond average
)

images.shape # (4, 3, 256, 256)

That's it!
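
To make the role of cond_scale more concrete, here is a minimal sketch of how classifier-free guidance is commonly applied to logits: start from the unconditional prediction and move along the direction toward the conditional one. It uses plain Python lists rather than tensors, and the function name guided_logits is hypothetical; treat it as an illustration of the idea, not as the library's code.

```python
def guided_logits(cond_logits, null_logits, cond_scale):
    """Extrapolate from unconditional toward conditional logits.

    cond_scale = 1 returns the conditional logits unchanged,
    cond_scale > 1 pushes further in the conditional direction.
    """
    return [n + (c - n) * cond_scale for c, n in zip(cond_logits, null_logits)]


cond = [2.0, 0.0]   # logits computed with the text condition
null = [1.0, 1.0]   # logits computed with the condition dropped

print(guided_logits(cond, null, cond_scale=1.0))  # same as cond
print(guided_logits(cond, null, cond_scale=3.0))  # amplified difference
```

This also suggests why null_cond_prob is needed during training: the model has to learn an unconditional prediction for the extrapolation to have a starting point.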

Ranking the generations

Train CLIP

import torch
from dalle_pytorch import CLIP

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 10000,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    num_visual_tokens = 512,
    visual_enc_depth = 6,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
)

text = torch.randint(0, 10000, (4, 256))
images = torch.randn(4, 3, 256, 256)
mask = torch.ones_like(text).bool()

loss = clip(text, images, text_mask = mask, return_loss = True)
loss.backward()

To get the similarity scores from your trained Clipper, just do

images, scores = dalle.generate_images(text, mask = mask, clip = clip)

scores.shape # (4,)  one score per text prompt in the batch of 4
images.shape # (4, 3, 256, 256)

# do your topk here, in paper they sampled 512 and chose top 32
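
The top-k step mentioned in the comment above can be sketched as follows. The example works on a plain list of scores so it runs without any model; the helper name top_k_indices and the score values are made up for illustration. With tensors, torch.topk would serve the same purpose.

```python
def top_k_indices(scores, k):
    """Return the indices of the k highest scores, best first."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return order[:k]


# hypothetical similarity scores for 4 generated images
scores = [0.12, 0.87, 0.45, 0.66]
best = top_k_indices(scores, k=2)
print(best)  # indices of the two best-matching images
```

The returned indices can then be used to select the corresponding images from the generated batch.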

Or you can just use the official CLIP model to rank the images from DALL-E

Scaling depth

In the blog post, they used 64 layers to achieve their results. I added reversible networks, from the Reformer paper, so that users can attempt to scale depth at the cost of compute. Reversible networks allow you to scale depth without activation memory growing with the number of layers, at a little over 2x compute cost (each layer is rerun on the backward pass).

Simply set the reversible keyword to True for the DALLE class

dalle = DALLE(
    dim = 1024,
    vae = vae,
    num_text_tokens = 10000,
    text_seq_len = 256,
    depth = 64,
    heads = 16,
    reversible = True  # <-- reversible networks https://arxiv.org/abs/2001.04451
)
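
The memory saving comes from the fact that a reversible block's inputs can be recomputed from its outputs, so intermediate activations do not have to be stored. The toy sketch below shows that property on scalars; f and g are arbitrary stand-ins for the attention and feedforward sublayers, and none of this is the repository's implementation.

```python
def f(x):
    """Stand-in for the first sublayer (e.g. attention)."""
    return 0.5 * x + 1.0


def g(x):
    """Stand-in for the second sublayer (e.g. feedforward)."""
    return x * x


def reversible_forward(x1, x2):
    y1 = x1 + f(x2)
    y2 = x2 + g(y1)
    return y1, y2


def reversible_inverse(y1, y2):
    # Undo the two residual additions in reverse order.
    x2 = y2 - g(y1)
    x1 = y1 - f(x2)
    return x1, x2


y1, y2 = reversible_forward(1.0, 2.0)
print(reversible_inverse(y1, y2))  # recovers the original inputs
```

Recomputing f and g during the backward pass is what accounts for the extra compute mentioned above.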

Sparse Attention

The blogpost alluded to a mixture of different types of sparse attention, used mainly on the image (while the text presumably had full causal attention). I have done my best to replicate these types of sparse attention, based on the scant details released. Primarily, it seems as though they are doing causal axial row / column attention, combined with a causal convolution-like attention.

By default DALLE will use full attention for all layers, but you can specify the attention type per layer as follows.

  • full: full attention

  • axial_row: axial attention, along the rows of the image feature map

  • axial_col: axial attention, along the columns of the image feature map

  • conv_like: convolution-like attention, for the image feature map

The sparse attention only applies to the image. Text will always receive full attention, as said in the blogpost.

dalle = DALLE(
    dim = 1024,
    vae = vae,
    num_text_tokens = 10000,
    text_seq_len = 256,
    depth = 64,
    heads = 16,
    reversible = True,
    attn_types = ('full', 'axial_row', 'axial_col', 'conv_like')  # cycles between these four types of attention
)
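
The comment "cycles between these four types of attention" can be illustrated with a short standalone sketch: the tuple is repeated until every layer has been assigned a type. The helper name layer_attention_types is hypothetical and only demonstrates the cycling behaviour described above.

```python
from collections import Counter
from itertools import cycle, islice


def layer_attention_types(attn_types, depth):
    """Assign an attention type to each of `depth` layers by cycling."""
    return list(islice(cycle(attn_types), depth))


attn_types = ('full', 'axial_row', 'axial_col', 'conv_like')

first_six = layer_attention_types(attn_types, depth=6)
print(first_six)

# With depth = 64, each of the four types is used for 16 layers.
counts = Counter(layer_attention_types(attn_types, depth=64))
print(counts)
```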

Deepspeed Sparse Attention

You can also train with Microsoft Deepspeed's <a href="https://www.deepspe

Core symbols most depended-on inside this repo

exists · called by 11 · dalle_pytorch/dalle_pytorch.py
exists · called by 10 · dalle_pytorch/attention.py
get_world_size · called by 8 · dalle_pytorch/distributed_backends/distributed_backend.py
is_root_worker · called by 8 · dalle_pytorch/distributed_backends/distributed_backend.py
backward · called by 6 · dalle_pytorch/reversible.py
encode · called by 6 · dalle_pytorch/tokenizer.py
exists · called by 6 · dalle_pytorch/transformer.py
require_init · called by 6 · dalle_pytorch/distributed_backends/distributed_backend.py

Shape

Method 140
Function 52
Class 35

Languages

Python 100%

Modules by API surface

dalle_pytorch/dalle_pytorch.py · 37 symbols
dalle_pytorch/transformer.py · 30 symbols
dalle_pytorch/tokenizer.py · 26 symbols
dalle_pytorch/distributed_backends/distributed_backend.py · 22 symbols
dalle_pytorch/vae.py · 21 symbols
dalle_pytorch/reversible.py · 18 symbols
dalle_pytorch/attention.py · 18 symbols
dalle_pytorch/distributed_backends/deepspeed_backend.py · 13 symbols
dalle_pytorch/distributed_backends/horovod_backend.py · 10 symbols
dalle_pytorch/distributed_backends/dummy_backend.py · 10 symbols
train_dalle.py · 9 symbols
dalle_pytorch/loader.py · 7 symbols

For agents

$ claude mcp add DALLE-pytorch \
  -- python -m otcore.mcp_server <graph>
