Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with the multistep DPMSolver. Args: model_output (`torch.Tensor`): The direct output from the learned diffusion model. timestep (`int` or `torch.Tensor`): The current discrete timestep in the diffusion chain.
(
self,
model_output: torch.Tensor,
timestep: Union[int, torch.Tensor],
sample: torch.Tensor,
generator=None,
variance_noise: Optional[torch.Tensor] = None,
return_dict: bool = True,
)
| 706 | |
| 707 | # Modified from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.step |
| 708 | def step( |
| 709 | self, |
| 710 | model_output: torch.Tensor, |
| 711 | timestep: Union[int, torch.Tensor], |
| 712 | sample: torch.Tensor, |
| 713 | generator=None, |
| 714 | variance_noise: Optional[torch.Tensor] = None, |
| 715 | return_dict: bool = True, |
| 716 | ) -> Union[SchedulerOutput, Tuple]: |
| 717 | """ |
| 718 | Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with |
| 719 | the multistep DPMSolver. |
| 720 | Args: |
| 721 | model_output (`torch.Tensor`): |
| 722 | The direct output from learned diffusion model. |
| 723 | timestep (`int`): |
| 724 | The current discrete timestep in the diffusion chain. |
| 725 | sample (`torch.Tensor`): |
| 726 | A current instance of a sample created by the diffusion process. |
| 727 | generator (`torch.Generator`, *optional*): |
| 728 | A random number generator. |
| 729 | variance_noise (`torch.Tensor`): |
| 730 | Alternative to generating noise with `generator` by directly providing the noise for the variance |
| 731 | itself. Useful for methods such as [`LEdits++`]. |
| 732 | return_dict (`bool`): |
| 733 | Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. |
| 734 | Returns: |
| 735 | [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: |
| 736 | If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a |
| 737 | tuple is returned where the first element is the sample tensor. |
| 738 | """ |
| 739 | if self.num_inference_steps is None: |
| 740 | raise ValueError( |
| 741 | "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" |
| 742 | ) |
| 743 | |
| 744 | if self.step_index is None: |
| 745 | self._init_step_index(timestep) |
| 746 | |
| 747 | # Improve numerical stability for small number of steps |
| 748 | lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( |
| 749 | self.config.euler_at_final or |
| 750 | (self.config.lower_order_final and len(self.timesteps) < 15) or |
| 751 | self.config.final_sigmas_type == "zero") |
| 752 | lower_order_second = ((self.step_index == len(self.timesteps) - 2) and |
| 753 | self.config.lower_order_final and |
| 754 | len(self.timesteps) < 15) |
| 755 | |
| 756 | model_output = self.convert_model_output(model_output, sample=sample) |
| 757 | for i in range(self.config.solver_order - 1): |
| 758 | self.model_outputs[i] = self.model_outputs[i + 1] |
| 759 | self.model_outputs[-1] = model_output |
| 760 | |
| 761 | # Upcast to avoid precision issues when computing prev_sample |
| 762 | sample = sample.to(torch.float32) |
| 763 | if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++" |
| 764 | ] and variance_noise is None: |
| 765 | noise = randn_tensor( |
no test coverage detected