
Implementation of Imagen, Google's Text-to-Image Neural Network that beats DALL-E2, in Pytorch. It is the new SOTA for text-to-image synthesis.
Architecturally, it is actually much simpler than DALL-E2. It consists of a cascading DDPM conditioned on text embeddings from a large pretrained T5 model (attention network). It also contains dynamic clipping for improved classifier free guidance, noise level conditioning, and a memory efficient unet design.
It appears neither CLIP nor prior network is needed after all. And so research continues.
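The two sampling ideas named above, classifier-free guidance and dynamic clipping, can be sketched in a few lines. This is a minimal illustration on plain Python lists rather than tensors, not the code this repository runs; the function names and the percentile default are assumptions made for the example.

```python
def guided_prediction(cond_pred, uncond_pred, cond_scale):
    """Classifier-free guidance: extrapolate from the unconditional
    prediction toward (and past) the text-conditioned one."""
    return [u + (c - u) * cond_scale for c, u in zip(cond_pred, uncond_pred)]


def dynamic_threshold(x0, percentile=0.95):
    """Clip a predicted image to [-s, s] and divide by s, where s is a
    percentile of the absolute pixel values (never smaller than 1)."""
    magnitudes = sorted(abs(v) for v in x0)
    # nearest-rank percentile; `percentile` is an illustrative default
    index = min(len(magnitudes) - 1, int(percentile * len(magnitudes)))
    s = max(magnitudes[index], 1.0)
    return [max(-s, min(s, v)) / s for v in x0]
```

With cond_scale = 1 the guided prediction is just the conditional one. Larger scales tend to push pixel values out of range, which is the situation the thresholding step is meant to handle.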
AI Coffee Break with Letitia | Assembly AI | Yannic Kilcher
Please join if you are interested in helping out with the replication with the LAION community
StabilityAI for the generous sponsorship, as well as my other sponsors out there
🤗 Huggingface for their amazing transformers library. The text encoder portion is pretty much taken care of because of them
Jonathan Ho for bringing about a revolution in generative artificial intelligence through his seminal paper
Sylvain and Zachary for the Accelerate library, which this repository uses for distributed training
Jorge Gomes for helping out with the T5 loading code and advice on the correct T5 version
Katherine Crowson, for her beautiful code, which helped me understand the continuous time version of gaussian diffusion
Marunine and Netruk44, for reviewing code, sharing experimental results, and help with debugging
Marunine for providing a potential solution for a color shifting issue in the memory efficient u-nets. Thanks to Jacob for sharing experimental comparisons between the base and memory-efficient unets
Marunine for finding numerous bugs, resolving an issue with resize right, and for sharing his experimental configurations and results
MalumaDev for proposing the use of pixel shuffle upsampler to fix checkerboard artifacts
Valentin for pointing out insufficient skip connections in the unet, as well as the specific method of attention conditioning in the base-unet in the appendix
BIGJUN for catching a big bug with continuous time gaussian diffusion noise level conditioning at inference time
Bingbing for identifying a bug with sampling and order of normalizing and noising with low resolution conditioning image
Kay for contributing one line command training of Imagen!
Hadrien Reynaud for testing out text-to-video on a medical dataset, sharing his results, and identifying issues!
$ pip install imagen-pytorch
import torch
from imagen_pytorch import Unet, Imagen
# unet for imagen
unet1 = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = 3,
    layer_attns = (False, True, True, True),
    layer_cross_attns = (False, True, True, True)
)
unet2 = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = (2, 4, 8, 8),
    layer_attns = (False, False, False, True),
    layer_cross_attns = (False, False, False, True)
)
# imagen, which contains the unets above (base unet and super resoluting ones)
imagen = Imagen(
    unets = (unet1, unet2),
    image_sizes = (64, 256),
    timesteps = 1000,
    cond_drop_prob = 0.1
).cuda()
# mock images (get a lot of this) and text encodings from large T5
text_embeds = torch.randn(4, 256, 768).cuda()
images = torch.randn(4, 3, 256, 256).cuda()
# feed images into imagen, training each unet in the cascade
for i in (1, 2):
    loss = imagen(images, text_embeds = text_embeds, unet_number = i)
    loss.backward()
# do the above for many many many many steps
# now you can sample an image based on the text embeddings from the cascading ddpm
images = imagen.sample(texts = [
    'a whale breaching from afar',
    'young girl blowing out candles on her birthday cake',
    'fireworks with blue and green sparkles'
], cond_scale = 3.)
images.shape # (3, 3, 256, 256)
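The cond_drop_prob = 0.1 setting above is what makes classifier-free guidance possible at sampling time: during training, the text conditioning is randomly withheld for a fraction of samples so the same network also learns an unconditional prediction. A minimal sketch of that idea, assuming a per-sample drop and a placeholder null embedding (both are simplifications for illustration, and drop_conditioning is a hypothetical helper, not part of the library):

```python
import random


def drop_conditioning(text_embeds, null_embed, cond_drop_prob, rng):
    """Swap each sample's text embedding for a null embedding
    with probability cond_drop_prob (hypothetical helper)."""
    return [
        null_embed if rng.random() < cond_drop_prob else embed
        for embed in text_embeds
    ]


rng = random.Random(0)
batch = ['embed-0', 'embed-1', 'embed-2', 'embed-3']
print(drop_conditioning(batch, None, 0.1, rng))
```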
For simpler training, you can directly supply text strings instead of precomputing text encodings. (Although for scaling purposes, you will definitely want to precompute the textual embeddings + mask)
The number of textual captions must match the batch size of the images if you go this route.
# mock images and text (get a lot of this)
texts = [
    'a child screaming at finding a worm within a half-eaten apple',
    'lizard running across the desert on two feet',
    'waking up to a psychedelic landscape',
    'seashells sparkling in the shallow waters'
]
images = torch.randn(4, 3, 256, 256).cuda()
# feed images into imagen, training each unet in the cascade
for i in (1, 2):
    loss = imagen(images, texts = texts, unet_number = i)
    loss.backward()
With the ImagenTrainer wrapper class, the exponential moving averages for all of the U-nets in the cascading DDPM will be automatically taken care of when calling update
import torch
from imagen_pytorch import Unet, Imagen, ImagenTrainer
# unet for imagen
unet1 = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = 3,
    layer_attns = (False, True, True, True),
)
unet2 = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = (2, 4, 8, 8),
    layer_attns = (False, False, False, True),
    layer_cross_attns = (False, False, False, True)
)
# imagen, which contains the unets above (base unet and super resoluting ones)
imagen = Imagen(
    unets = (unet1, unet2),
    text_encoder_name = 't5-large',
    image_sizes = (64, 256),
    timesteps = 1000,
    cond_drop_prob = 0.1
).cuda()
# wrap imagen with the trainer class
trainer = ImagenTrainer(imagen)
# mock images (get a lot of this) and text encodings from large T5
text_embeds = torch.randn(64, 256, 1024).cuda()
images = torch.randn(64, 3, 256, 256).cuda()
# feed images into imagen, training each unet in the cascade
loss = trainer(
    images,
    text_embeds = text_embeds,
    unet_number = 1, # training on unet number 1 in this example, but you will have to also save checkpoints and then reload and continue training on unet number 2
    max_batch_size = 4 # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory
)
trainer.update(unet_number = 1)
# do the above for many many many many steps
# now you can sample an image based on the text embeddings from the cascading ddpm
images = trainer.sample(texts = [
    'a puppy looking anxiously at a giant donut on the table',
    'the milky way galaxy in the style of monet'
], cond_scale = 3.)
images.shape # (2, 3, 256, 256)
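For readers unfamiliar with the exponential moving average that trainer.update maintains, the rule itself is simple. The sketch below shows it on plain lists of floats; the function name and the beta value are assumptions for illustration, and the trainer's actual schedule and defaults may differ.

```python
def ema_update(ema_params, online_params, beta=0.99):
    """Move each averaged parameter a small step toward the freshly
    trained one; a larger beta gives a slower, smoother average."""
    return [beta * e + (1.0 - beta) * p for e, p in zip(ema_params, online_params)]


# toy run: the average drifts from 0.0 toward the online value 1.0
ema = [0.0]
for _ in range(500):
    ema = ema_update(ema, [1.0])
print(ema)
```

Sampling from the averaged weights rather than the raw ones generally gives more stable results, which is why the wrapper handles it for every unet in the cascade.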
You can also train Imagen without text (unconditional image generation) as follows
import torch
from imagen_pytorch import Unet, Imagen, SRUnet256, ImagenTrainer
# unets for unconditional imagen
unet1 = Unet(
    dim = 32,
    dim_mults = (1, 2, 4),
    num_resnet_blocks = 3,
    layer_attns = (False, True, True),
    layer_cross_attns = False,
    use_linear_attn = True
)
unet2 = SRUnet256(
    dim = 32,
    dim_mults = (1, 2, 4),
    num_resnet_blocks = (2, 4, 8),
    layer_attns = (False, False, True),
    layer_cross_attns = False
)
# imagen, which contains the unets above (base unet and super resoluting ones)
imagen = Imagen(
    condition_on_text = False, # this must be set to False for unconditional Imagen
    unets = (unet1, unet2),
    image_sizes = (64, 128),
    timesteps = 1000
)
trainer = ImagenTrainer(imagen).cuda()
# now get a ton of images and feed them through the Imagen trainer
training_images = torch.randn(4, 3, 256, 256).cuda()
# train each unet separately
# in this example, only training on unet number 1
loss = trainer(training_images, unet_number = 1)
trainer.update(unet_number = 1)
# do the above for many many many many steps
# now you can sample images unconditionally from the cascading unet(s)
images = trainer.sample(batch_size = 16) # (16, 3, 128, 128)
Or train only super-resoluting unets
import torch
from imagen_pytorch import Unet, NullUnet, Imagen
# unet for imagen
unet1 = NullUnet() # add a placeholder "null" unet for the base unet
unet2 = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = (2, 4, 8, 8),
    layer_attns = (False, False, False, True),
    layer_cross_attns = (False, False, False, True)
)
# imagen, which contains the unets above (base unet and super resoluting ones)
imagen = Imagen(
    unets = (unet1, unet2),
    image_sizes = (64, 256),
    timesteps = 250,
    cond_drop_prob = 0.1
).cuda()
# mock images (get a lot of this) and text encodings from large T5
text_embeds = torch.randn(4, 256, 768).cuda()
images = torch.randn(4, 3, 256, 256).cuda()
# feed images into imagen, training each unet in the cascade
loss = imagen(images, text_embeds = text_embeds, unet_number = 2)
loss.backward()
# do the above for many many many many steps
# now you can sample an image based on the text embeddings as well as low resolution images
lowres_images = torch.randn(3, 3, 64, 64).cuda() # starting un-resoluted images
images = imagen.sample(
    texts = [
        'a whale breaching from afar',
        'young girl blowing out candles on her birthday cake',
        'fireworks with blue and green sparkles'
    ],
    start_at_unet_number = 2, # start at unet number 2
    start_image_or_video = lowres_images, # pass in low resolution images to be resoluted
    cond_scale = 3.)
images.shape # (3, 3, 256, 256)
At any time you can save and load the trainer and all associated states with the save and load methods. It is recommended you use these methods instead of manually saving with a state_dict call, as there is some device memory management being done under the hood within the trainer.
ex.
trainer.save('./path/to/checkpoint.pt')
trainer.load('./path/to/checkpoint.pt')
trainer.steps # (2,) step number for each of the unets, in this case 2
You can also rely on the ImagenTrainer to automatically train off DataLoader instances. You simply have to craft your DataLoader to return either images (for the unconditional case) or a tuple of ('images', 'text_embeds') for text-guided generation.
ex. unconditional training
from imagen_pytorch import Unet, Imagen, ImagenTrainer
from imagen_pytorch.data import Dataset
# unets for unconditional imagen
unet = Unet(
    dim = 32,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = 1,
    layer_attns = (False, False, False, True),
    layer_cross_attns = False
)
# imagen, which contains the unet above
imagen = Imagen(
    condition_on_text = False, # this must be set to False for unconditional Imagen
    unets = unet,
    image_sizes = 128,
    timesteps = 1000
)
trainer = ImagenTrainer(
    imagen = imagen,
    split_valid_from_train = True # whether to split the validation dataset from the training
).cuda()
# instantiate your dataloader, which returns the necessary inputs to the DDPM as tuple in the order of images, text embeddings, then text masks. in this case, only images is returned as it is unconditional training
dataset = Dataset('/path/to/training/images', image_size = 128)
trainer.add_train_dataset(dataset, batch_size = 16)
# working training loop
for i in range(200000):
    loss = trainer.train_step(unet_number = 1, max_batch_size = 4)
    print(f'loss: {loss}')
    if not (i % 50):
        valid_loss = trainer.valid_step(unet_number = 1, max_batch_size = 4)
        print(f'valid loss: {valid_loss}')
    if not (i % 100) and trainer.is_main: # is_main makes sure this can run in distributed
        images = trainer.sample(batch_size = 1, return_pil_images = True) # returns List[Image]
        images[0].save(f'./sample-{i // 100}.png')
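The max_batch_size argument used in the loop splits each batch into smaller chunks and accumulates gradients across them. A minimal sketch of why that leaves the result unchanged, assuming each chunk's loss is weighted by its share of the batch (the helper names here are hypothetical and the real trainer operates on tensors):

```python
def split_into_chunks(batch, max_batch_size):
    """Yield (fraction_of_batch, chunk) pairs, each chunk no larger
    than max_batch_size."""
    for start in range(0, len(batch), max_batch_size):
        chunk = batch[start:start + max_batch_size]
        yield len(chunk) / len(batch), chunk


def accumulated_loss(batch, max_batch_size, loss_fn):
    """Sum the chunk losses weighted by chunk size, which matches the
    mean loss over the full batch when loss_fn is a per-chunk mean."""
    return sum(
        frac * loss_fn(chunk)
        for frac, chunk in split_into_chunks(batch, max_batch_size)
    )


def mean(values):
    return sum(values) / len(values)


# a batch of 10 items processed in chunks of 4, 4 and 2
print(accumulated_loss(list(range(10)), 4, mean))
```

Only one chunk needs to fit in memory at a time, so the effective batch size can stay large on a small GPU at the cost of more forward and backward passes per step.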
Thanks to 🤗 Accelerate, you can do multi GPU training easily with two steps.
First you need to invoke accelerate config in the same directory as your training script (say it is named train.py)
$ accelerate config
Next, instead of calling python train.py as you would for single GPU, you would use the accelerate CLI as so
$ accelerate launch train.