| 288 | |
| 289 | |
| 290 | class DPMPP2MSampler(BaseDiffusionSampler): |
def get_variables(self, sigma, next_sigma, previous_sigma=None):
    """Compute the step quantities used by the multistep update.

    ``sigma`` and ``next_sigma`` are each mapped through
    ``to_neg_log_sigma`` to obtain ``t`` and ``t_next``; the step size is
    ``h = t_next - t``.

    Args:
        sigma: Current noise level.
        next_sigma: Noise level targeted by this step.
        previous_sigma: Noise level of the preceding step, or ``None``
            when there is none.

    Returns:
        A tuple ``(h, r, t, t_next)``. ``r`` is the previous step size
        (``t`` minus the mapped ``previous_sigma``) divided by ``h``, or
        ``None`` when ``previous_sigma`` is ``None``.
    """
    t = to_neg_log_sigma(sigma)
    t_next = to_neg_log_sigma(next_sigma)
    h = t_next - t

    # Without a previous noise level there is no step-size ratio to report.
    if previous_sigma is None:
        return h, None, t, t_next

    h_last = t - to_neg_log_sigma(previous_sigma)
    return h, h_last / h, t, t_next
| 301 | |
def get_mult(self, h, r, t, t_next, previous_sigma):
    """Return the multipliers applied in ``sampler_step``.

    Args:
        h: Step size, as returned by ``get_variables``.
        r: Step-size ratio; only read when ``previous_sigma`` is not
            ``None``.
        t: Current time value, passed through ``to_sigma``.
        t_next: Next time value, passed through ``to_sigma``.
        previous_sigma: Only tested against ``None`` to choose between
            the two-value and four-value result.

    Returns:
        ``(mult1, mult2)`` when ``previous_sigma`` is ``None``, otherwise
        ``(mult1, mult2, mult3, mult4)`` where

        * ``mult1 = to_sigma(t_next) / to_sigma(t)``
        * ``mult2 = expm1(-h)``
        * ``mult3 = 1 + 1 / (2 * r)``
        * ``mult4 = 1 / (2 * r)``
    """
    sigma_ratio = to_sigma(t_next) / to_sigma(t)
    expm1_neg_h = (-h).expm1()

    # Only the two basic multipliers exist when there is no previous step.
    if previous_sigma is None:
        return sigma_ratio, expm1_neg_h

    half_inv_r = 1 / (2 * r)
    return sigma_ratio, expm1_neg_h, 1 + half_inv_r, half_inv_r
| 312 | |
def sampler_step(
    self,
    old_denoised,
    previous_sigma,
    sigma,
    next_sigma,
    denoiser,
    x,
    cond,
    uc=None,
):
    """Advance ``x`` from ``sigma`` to ``next_sigma`` by one sampler step.

    Args:
        old_denoised: Denoised output kept from the preceding step, or
            ``None`` if there is none.
        previous_sigma: Noise level of the preceding step, or ``None``.
        sigma: Current noise level.
        next_sigma: Noise level targeted by this step.
        denoiser: Forwarded to ``self.denoise``.
        x: Current sample.
        cond: Conditioning forwarded to ``self.denoise``.
        uc: Optional unconditional conditioning forwarded to
            ``self.denoise``.

    Returns:
        A tuple ``(x_next, denoised)``: the updated sample and the
        denoised output computed during this step.
    """
    denoised = self.denoise(x, denoiser, sigma, cond, uc)

    h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma)
    # Expand every multiplier to x's rank so it can be applied to x.
    coeffs = [
        append_dims(coeff, x.ndim)
        for coeff in self.get_mult(h, r, t, t_next, previous_sigma)
    ]

    first_order = coeffs[0] * x - coeffs[1] * denoised

    # With no previous denoised output, or when the next noise levels sum
    # to (numerically) zero, the plain update is the result.
    if old_denoised is None or torch.sum(next_sigma) < 1e-14:
        return first_order, denoised

    # Combine the current and previous denoised outputs, then redo the
    # update with that combination.
    combined = coeffs[2] * denoised - coeffs[3] * old_denoised
    second_order = coeffs[0] * x - coeffs[1] * combined

    # Keep the corrected update only where the next noise level is
    # positive; elsewhere fall back to the plain update.
    use_correction = append_dims(next_sigma, x.ndim) > 0.0
    x = torch.where(use_correction, second_order, first_order)

    return x, denoised
| 346 | |
| 347 | def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs): |
no outgoing calls
no test coverage detected