| 168 | |
| 169 | |
def get_scheduler(name: str, config: "dict | None" = None):
    """Build a diffusers scheduler instance from a scheduler name.

    Args:
        name: Scheduler identifier, compared against ``DiffusionScheduler``
            members. A leading ``"k_"`` selects the Karras-sigma variant of
            the scheduler named by the remainder (e.g. ``"k_euler"``).
        config: Scheduler configuration handed to ``from_config``. It is
            copied before any keys are added, so neither the caller's mapping
            nor a shared default is ever modified. ``None`` (the default)
            means an empty configuration.

    Returns:
        The result of ``sched_class.from_config(...)`` for the selected class.

    Raises:
        ValueError: If ``name`` (after removing a ``"k_"`` prefix) matches no
            known scheduler.
    """
    # Work on a private copy. The previous version mutated both a shared
    # ``{}`` default, which leaked keys such as ``use_karras_sigmas`` into
    # later calls, and the caller's own dict.
    config = {} if config is None else dict(config)

    karras_prefix = "k_"
    is_karras = name.startswith(karras_prefix)
    if is_karras:
        # Remove exactly the prefix. str.lstrip("k_") would strip any run of
        # leading "k"/"_" characters, not the literal two-character prefix.
        name = name[len(karras_prefix):]
        config["use_karras_sigmas"] = True

    if name == DiffusionScheduler.ddim:
        sched_class = DDIMScheduler
    elif name == DiffusionScheduler.pndm:
        sched_class = PNDMScheduler
    elif name == DiffusionScheduler.heun:
        sched_class = HeunDiscreteScheduler
    elif name == DiffusionScheduler.unipc:
        sched_class = UniPCMultistepScheduler
    elif name == DiffusionScheduler.euler:
        sched_class = EulerDiscreteScheduler
    elif name == DiffusionScheduler.euler_a:
        sched_class = EulerAncestralDiscreteScheduler
    elif name == DiffusionScheduler.lms:
        sched_class = LMSDiscreteScheduler
    elif name == DiffusionScheduler.dpm_2:
        # Equivalent to `DPM2` in K-Diffusion
        sched_class = KDPM2DiscreteScheduler
    elif name == DiffusionScheduler.dpm_2_a:
        # Equivalent to `DPM2 a` in K-Diffusion
        sched_class = KDPM2AncestralDiscreteScheduler
    elif name == DiffusionScheduler.dpmpp_2m:
        # Equivalent to `DPM++ 2M` in K-Diffusion
        sched_class = DPMSolverMultistepScheduler
        config["algorithm_type"] = "dpmsolver++"
        config["solver_order"] = 2
    elif name == DiffusionScheduler.dpmpp_sde:
        # Equivalent to `DPM++ SDE` in K-Diffusion
        sched_class = DPMSolverSinglestepScheduler
    elif name == DiffusionScheduler.dpmpp_2m_sde:
        # Equivalent to `DPM++ 2M SDE` in K-Diffusion
        sched_class = DPMSolverMultistepScheduler
        config["algorithm_type"] = "sde-dpmsolver++"
    else:
        raise ValueError(f"Invalid scheduler '{'k_' if is_karras else ''}{name}'")

    return sched_class.from_config(config)
| 213 | |
| 214 | |
| 215 | # Implement the BackendServicer class with the service methods |