Encodes images into packed latent tokens and latent image ids for the FLUX pipeline.
(pipeline: FluxPipeline, images: torch.Tensor)
| 35 | |
| 36 | |
def encode_images(pipeline: FluxPipeline, images: torch.Tensor):
    """
    Encode images into packed latent tokens and latent image ids for FLUX.

    The images are run through ``pipeline.image_processor.preprocess``, moved
    to ``pipeline.device`` / ``pipeline.dtype``, and encoded by
    ``pipeline.vae``. A *sample* is drawn from the latent distribution, so the
    result is random unless the RNG is seeded. The latents are normalised with
    the VAE config's ``shift_factor`` and ``scaling_factor``, then packed with
    ``pipeline._pack_latents``.

    Args:
        pipeline: FLUX pipeline. It must provide ``image_processor``, ``vae``,
            ``device``, ``dtype`` and the private ``_pack_latents`` /
            ``_prepare_latent_image_ids`` helpers.
        images: Input accepted by ``pipeline.image_processor.preprocess``.

    Returns:
        Tuple ``(images_tokens, images_ids)``. The first element is the packed
        latents returned by ``_pack_latents``. The second is the ids from
        ``_prepare_latent_image_ids``, whose token count matches
        ``images_tokens.shape[1]``.
    """
    images = pipeline.image_processor.preprocess(images)
    # One transfer that sets device and dtype together (avoids an extra copy).
    images = images.to(device=pipeline.device, dtype=pipeline.dtype)
    latents = pipeline.vae.encode(images).latent_dist.sample()
    latents = (
        latents - pipeline.vae.config.shift_factor
    ) * pipeline.vae.config.scaling_factor
    # _pack_latents takes (latents, batch, channels, height, width).
    images_tokens = pipeline._pack_latents(latents, *latents.shape)

    batch_size, _, latent_height, latent_width = latents.shape
    images_ids = pipeline._prepare_latent_image_ids(
        batch_size,
        latent_height,
        latent_width,
        pipeline.device,
        pipeline.dtype,
    )
    # The fallback below implies _prepare_latent_image_ids behaves differently
    # across diffusers versions. NOTE(review): as far as we know, older releases
    # halve height/width internally and return (batch, tokens, 3), while newer
    # ones expect the halved size and return (tokens, 3). Confirm against the
    # pinned diffusers version.
    # Compare against the token axis (shape[-2]), which is valid for both
    # layouts. shape[0] would be the batch size in the 3-D layout.
    if images_tokens.shape[1] != images_ids.shape[-2]:
        images_ids = pipeline._prepare_latent_image_ids(
            batch_size,
            latent_height // 2,
            latent_width // 2,
            pipeline.device,
            pipeline.dtype,
        )
    return images_tokens, images_ids
| 64 | |
| 65 | |
# NOTE(review): module-level global initialised to None here; presumably
# assigned (a depth pipeline, going by the name) and read elsewhere in the
# file — confirm against its users.
depth_pipe = None
no test coverage detected