(
self,
prompt,
negative_prompt="",
input_video=None,
denoising_strength=1.0,
seed=None,
rand_device="cpu",
height=544,
width=992,
num_frames=204,
cfg_scale=9.0,
num_inference_steps=30,
tiled=True,
tile_size=(34, 34),
tile_stride=(16, 16),
smooth_scale=0.6,
progress_bar_cmd=lambda x: x,
progress_bar_st=None,
)
| 149 | |
| 150 | @torch.no_grad() |
| 151 | def __call__( |
| 152 | self, |
| 153 | prompt, |
| 154 | negative_prompt="", |
| 155 | input_video=None, |
| 156 | denoising_strength=1.0, |
| 157 | seed=None, |
| 158 | rand_device="cpu", |
| 159 | height=544, |
| 160 | width=992, |
| 161 | num_frames=204, |
| 162 | cfg_scale=9.0, |
| 163 | num_inference_steps=30, |
| 164 | tiled=True, |
| 165 | tile_size=(34, 34), |
| 166 | tile_stride=(16, 16), |
| 167 | smooth_scale=0.6, |
| 168 | progress_bar_cmd=lambda x: x, |
| 169 | progress_bar_st=None, |
| 170 | ): |
| 171 | # Tiler parameters |
| 172 | tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride} |
| 173 | |
| 174 | # Scheduler |
| 175 | self.scheduler.set_timesteps(num_inference_steps, denoising_strength) |
| 176 | |
| 177 | # Initialize noise |
| 178 | latents = self.generate_noise((1, max(num_frames//17*3, 1), 64, height//16, width//16), seed=seed, device=rand_device, dtype=self.torch_dtype).to(self.device) |
| 179 | |
| 180 | # Encode prompts |
| 181 | self.load_models_to_device(["text_encoder_1", "text_encoder_2"]) |
| 182 | prompt_emb_posi = self.encode_prompt(prompt, positive=True) |
| 183 | if cfg_scale != 1.0: |
| 184 | prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False) |
| 185 | |
| 186 | # Denoise |
| 187 | self.load_models_to_device(["dit"]) |
| 188 | for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): |
| 189 | timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) |
| 190 | print(f"Step {progress_id + 1} / {len(self.scheduler.timesteps)}") |
| 191 | |
| 192 | # Inference |
| 193 | noise_pred_posi = self.dit(latents, timestep=timestep, **prompt_emb_posi) |
| 194 | if cfg_scale != 1.0: |
| 195 | noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_nega) |
| 196 | noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) |
| 197 | else: |
| 198 | noise_pred = noise_pred_posi |
| 199 | |
| 200 | # Scheduler |
| 201 | latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents) |
| 202 | |
| 203 | # Decode |
| 204 | self.load_models_to_device(['vae']) |
| 205 | frames = self.vae.decode(latents, device=self.device, smooth_scale=smooth_scale, **tiler_kwargs) |
| 206 | self.load_models_to_device([]) |
| 207 | frames = self.tensor2video(frames[0]) |
| 208 |
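Nothing calls `__call__` directly by name. (As a `__call__` method it is invoked implicitly when an instance is called, so the absence of direct callers does not by itself mean the method is unused.)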
No test coverage was detected for `__call__`.