(
self,
prompt,
local_prompts=[],
masks=[],
mask_scales=[],
negative_prompt="",
cfg_scale=7.5,
input_image=None,
denoising_strength=1.0,
height=1024,
width=1024,
num_inference_steps=20,
t5_sequence_length=77,
tiled=False,
tile_size=128,
tile_stride=64,
seed=None,
progress_bar_cmd=tqdm,
progress_bar_st=None,
)
| 72 | |
| 73 | @torch.no_grad() |
| 74 | def __call__( |
| 75 | self, |
| 76 | prompt, |
| 77 | local_prompts=[], |
| 78 | masks=[], |
| 79 | mask_scales=[], |
| 80 | negative_prompt="", |
| 81 | cfg_scale=7.5, |
| 82 | input_image=None, |
| 83 | denoising_strength=1.0, |
| 84 | height=1024, |
| 85 | width=1024, |
| 86 | num_inference_steps=20, |
| 87 | t5_sequence_length=77, |
| 88 | tiled=False, |
| 89 | tile_size=128, |
| 90 | tile_stride=64, |
| 91 | seed=None, |
| 92 | progress_bar_cmd=tqdm, |
| 93 | progress_bar_st=None, |
| 94 | ): |
| 95 | height, width = self.check_resize_height_width(height, width) |
| 96 | |
| 97 | # Tiler parameters |
| 98 | tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride} |
| 99 | |
| 100 | # Prepare scheduler |
| 101 | self.scheduler.set_timesteps(num_inference_steps, denoising_strength) |
| 102 | |
| 103 | # Prepare latent tensors |
| 104 | if input_image is not None: |
| 105 | self.load_models_to_device(['vae_encoder']) |
| 106 | image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype) |
| 107 | latents = self.encode_image(image, **tiler_kwargs) |
| 108 | noise = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype) |
| 109 | latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0]) |
| 110 | else: |
| 111 | latents = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype) |
| 112 | |
| 113 | # Encode prompts |
| 114 | self.load_models_to_device(['text_encoder_1', 'text_encoder_2', 'text_encoder_3']) |
| 115 | prompt_emb_posi = self.encode_prompt(prompt, positive=True, t5_sequence_length=t5_sequence_length) |
| 116 | prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length) |
| 117 | prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts] |
| 118 | |
| 119 | # Denoise |
| 120 | self.load_models_to_device(['dit']) |
| 121 | for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): |
| 122 | timestep = timestep.unsqueeze(0).to(self.device) |
| 123 | |
| 124 | # Classifier-free guidance |
| 125 | inference_callback = lambda prompt_emb_posi: self.dit( |
| 126 | latents, timestep=timestep, **prompt_emb_posi, **tiler_kwargs, |
| 127 | ) |
| 128 | noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback) |
| 129 | noise_pred_nega = self.dit( |
| 130 | latents, timestep=timestep, **prompt_emb_nega, **tiler_kwargs, |
| 131 | ) |
Nothing calls this directly.
No test coverage detected.