DPM-Solver++ (stochastic).
(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2)
| 546 | |
@torch.no_grad()
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
    """DPM-Solver++ (stochastic).

    Two-stage stochastic sampler working in log-sigma time ``t = -log(sigma)``.

    Each step does the following:
    1. Evaluate the model at ``sigmas[i]``.
    2. Take a first-order step to an intermediate point a fraction ``r`` of
       the way (in ``t``) towards ``sigmas[i + 1]``, add ancestral noise, and
       evaluate the model there.
    3. Blend the two denoised estimates for the full step, then add ancestral
       noise again.

    Any step whose target sigma is exactly 0 is a plain Euler step instead.

    Args:
        model: Denoiser called as ``model(x, sigma * s_in, **extra_args)``.
            Its output is used as the denoised prediction.
        x: Starting sample, updated step by step.
        sigmas: Noise schedule, walked from ``sigmas[0]`` to ``sigmas[-1]``.
            Presumably descending and ending in 0; confirm with callers.
        extra_args: Extra keyword arguments forwarded to every ``model`` call.
        callback: Optional callable. After the first model call of each step
            it receives a dict with keys ``x``, ``i``, ``sigma``,
            ``sigma_hat`` and ``denoised``.
        disable: Forwarded to ``trange`` (presumably hides the progress bar).
        eta: Ancestral-noise amount passed to ``get_ancestral_step``.
        s_noise: Extra multiplier on every injected noise sample.
        noise_sampler: Callable ``(sigma_from, sigma_to) -> noise``.
            Defaults to a ``BrownianTreeNoiseSampler`` built over the
            smallest positive sigma and the largest sigma.
        r: Position of the intermediate stage, as a fraction of the step in
            ``t``. Must be non-zero, since the second-stage weight is
            ``1 / (2 * r)``.

    Returns:
        The sample after the last step. If ``sigmas`` has fewer than two
        entries there is no step to take, and ``x`` is returned unchanged.
    """
    extra_args = {} if extra_args is None else extra_args

    # No step to take. Returning early also avoids building a noise sampler
    # (or reducing an empty tensor) for a schedule that is never walked.
    if len(sigmas) < 2:
        return x

    if noise_sampler is None:
        # The sigma range is only needed to construct the default sampler.
        sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
        noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max)

    s_in = x.new_ones([x.shape[0]])
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()

    for i in trange(len(sigmas) - 1, disable=disable):
        sigma, sigma_next = sigmas[i], sigmas[i + 1]
        denoised = model(x, sigma * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma, 'sigma_hat': sigma, 'denoised': denoised})

        if sigma_next == 0:
            # Euler method: log(0) is undefined, so the final step onto
            # sigma == 0 cannot use the log-sigma update below.
            d = to_d(x, sigma, denoised)
            dt = sigma_next - sigma
            x = x + d * dt
            continue

        # DPM-Solver++ in t = -log(sigma). sigma_fn(t_fn(.)) is kept rather
        # than using `sigma` directly so the numerics match exactly.
        t, t_next = t_fn(sigma), t_fn(sigma_next)
        h = t_next - t
        s = t + h * r
        fac = 1 / (2 * r)
        sigma_t, sigma_s, sigma_t_next = sigma_fn(t), sigma_fn(s), sigma_fn(t_next)

        # Stage 1: first-order data-prediction step from t towards s.
        # get_ancestral_step presumably returns (sigma_down, sigma_up): the
        # deterministic target and the std of the noise added back.
        sd, su = get_ancestral_step(sigma_t, sigma_s, eta)
        s_ = t_fn(sd)
        x_2 = (sigma_fn(s_) / sigma_t) * x - (t - s_).expm1() * denoised
        x_2 = x_2 + noise_sampler(sigma_t, sigma_s) * s_noise * su
        denoised_2 = model(x_2, sigma_s * s_in, **extra_args)

        # Stage 2: full step from t using a blend of both denoised estimates.
        # With the default r = 1/2, fac == 1, so this uses denoised_2 alone
        # (the midpoint estimate).
        sd, su = get_ancestral_step(sigma_t, sigma_t_next, eta)
        t_next_ = t_fn(sd)
        denoised_d = (1 - fac) * denoised + fac * denoised_2
        x = (sigma_fn(t_next_) / sigma_t) * x - (t - t_next_).expm1() * denoised_d
        x = x + noise_sampler(sigma_t, sigma_t_next) * s_noise * su
    return x
| 587 | |
| 588 | |
| 589 | @torch.no_grad() |
Nothing calls this function directly.
No test coverage detected.