| 554 | |
| 555 | @torch.no_grad() |
| 556 | def generate( |
| 557 | pipeline: FluxPipeline, |
| 558 | prompt: Union[str, List[str]] = None, |
| 559 | prompt_2: Optional[Union[str, List[str]]] = None, |
| 560 | height: Optional[int] = 512, |
| 561 | width: Optional[int] = 512, |
| 562 | num_inference_steps: int = 28, |
| 563 | timesteps: List[int] = None, |
| 564 | guidance_scale: float = 3.5, |
| 565 | num_images_per_prompt: Optional[int] = 1, |
| 566 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, |
| 567 | latents: Optional[torch.FloatTensor] = None, |
| 568 | prompt_embeds: Optional[torch.FloatTensor] = None, |
| 569 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, |
| 570 | output_type: Optional[str] = "pil", |
| 571 | return_dict: bool = True, |
| 572 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, |
| 573 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, |
| 574 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], |
| 575 | max_sequence_length: int = 512, |
| 576 | # Condition Parameters (Optional) |
| 577 | main_adapter: Optional[List[str]] = None, |
| 578 | conditions: List[Condition] = [], |
| 579 | condition_scale: float = 1.0, |
| 580 | image_guidance_scale: float = 1.0, |
| 581 | transformer_kwargs: Optional[Dict[str, Any]] = {}, |
| 582 | kv_cache=False, |
| 583 | latent_mask=None, |
| 584 | **params: dict, |
| 585 | ): |
| 586 | self = pipeline |
| 587 | |
| 588 | height = height or self.default_sample_size * self.vae_scale_factor |
| 589 | width = width or self.default_sample_size * self.vae_scale_factor |
| 590 | |
| 591 | # Check inputs. Raise error if not correct |
| 592 | self.check_inputs( |
| 593 | prompt, |
| 594 | prompt_2, |
| 595 | height, |
| 596 | width, |
| 597 | prompt_embeds=prompt_embeds, |
| 598 | pooled_prompt_embeds=pooled_prompt_embeds, |
| 599 | callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, |
| 600 | max_sequence_length=max_sequence_length, |
| 601 | ) |
| 602 | |
| 603 | self._guidance_scale = guidance_scale |
| 604 | self._joint_attention_kwargs = joint_attention_kwargs |
| 605 | |
| 606 | # Define call parameters |
| 607 | if prompt is not None and isinstance(prompt, str): |
| 608 | batch_size = 1 |
| 609 | elif prompt is not None and isinstance(prompt, list): |
| 610 | batch_size = len(prompt) |
| 611 | else: |
| 612 | batch_size = prompt_embeds.shape[0] |
| 613 | |