Sets the discrete timesteps used for the diffusion chain (to be run before inference). Args: num_inference_steps (`int`): The number of discrete timesteps to generate for the chain. device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
(
self,
num_inference_steps: Union[int, None] = None,
device: Union[str, torch.device] = None,
sigmas: Optional[List[float]] = None,
mu: Optional[Union[float, None]] = None,
shift: Optional[Union[float, None]] = None,
)
| 226 | |
| 227 | # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps |
def set_timesteps(
    self,
    num_inference_steps: Union[int, None] = None,
    device: Union[str, torch.device] = None,
    sigmas: Optional[List[float]] = None,
    mu: Optional[Union[float, None]] = None,
    shift: Optional[Union[float, None]] = None,
):
    """
    Sets the discrete timesteps used for the diffusion chain (to be run before inference).

    Args:
        num_inference_steps (`int`, *optional*):
            Number of timesteps to generate. Required when `sigmas` is not given; unused otherwise.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        sigmas (`List[float]` or `np.ndarray`, *optional*):
            Custom sigma schedule. When omitted, `num_inference_steps` values are taken from an evenly
            spaced grid running from `self.sigma_max` towards `self.sigma_min` (the last grid point is
            dropped).
        mu (`float`, *optional*):
            Value forwarded to `self.time_shift`. Required when `self.config.use_dynamic_shifting` is set.
        shift (`float`, *optional*):
            Static shift factor, used only when dynamic shifting is disabled. Falls back to
            `self.config.shift` when `None`.

    Raises:
        ValueError: If dynamic shifting is enabled and `mu` is `None`, if neither `num_inference_steps`
            nor `sigmas` is provided, or if `self.config.final_sigmas_type` is not `'zero'` or
            `'sigma_min'`.
    """

    if self.config.use_dynamic_shifting and mu is None:
        raise ValueError(
            " you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`"
        )

    if sigmas is None:
        if num_inference_steps is None:
            # Without this guard the `+ 1` below fails with an opaque TypeError.
            raise ValueError(
                "you have to pass either `num_inference_steps` or `sigmas`"
            )
        sigmas = np.linspace(self.sigma_max, self.sigma_min,
                             num_inference_steps + 1)[:-1]
    elif not isinstance(sigmas, np.ndarray):
        # The arithmetic below needs element-wise ndarray semantics; a plain
        # Python list would raise (float shift) or be repeated (int shift).
        sigmas = np.asarray(sigmas, dtype=np.float64)

    if self.config.use_dynamic_shifting:
        sigmas = self.time_shift(mu, 1.0, sigmas)  # pyright: ignore
    else:
        if shift is None:
            shift = self.config.shift
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)  # pyright: ignore

    if self.config.final_sigmas_type == "sigma_min":
        # NOTE(review): relies on `self.alphas_cumprod`, which is not set anywhere
        # in the visible code -- confirm the class defines it.
        sigma_last = ((1 - self.alphas_cumprod[0]) /
                      self.alphas_cumprod[0])**0.5
    elif self.config.final_sigmas_type == "zero":
        sigma_last = 0
    else:
        raise ValueError(
            f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
        )

    # Timesteps are computed before the terminal sigma is appended, so
    # `self.sigmas` ends up one element longer than `self.timesteps`.
    timesteps = sigmas * self.config.num_train_timesteps
    sigmas = np.concatenate([sigmas, [sigma_last]]).astype(
        np.float32)  # pyright: ignore

    self.sigmas = torch.from_numpy(sigmas)
    # The int64 cast truncates fractional timesteps.
    self.timesteps = torch.from_numpy(timesteps).to(
        device=device, dtype=torch.int64)

    self.num_inference_steps = len(timesteps)

    # One empty slot per solver order.
    self.model_outputs = [
        None,
    ] * self.config.solver_order
no test coverage detected