DPM-Solver-Fast (fixed step size). See https://arxiv.org/abs/2206.00927.
(model, x, sigma_min, sigma_max, n, extra_args=None, callback=None, disable=None, eta=0., s_noise=1., noise_sampler=None)
| 485 | |
@torch.no_grad()
def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback=None, disable=None, eta=0., s_noise=1., noise_sampler=None):
    """DPM-Solver-Fast (fixed step size). See https://arxiv.org/abs/2206.00927.

    Integrates from ``sigma_max`` down to ``sigma_min`` by delegating to
    ``DPMSolver.dpm_solver_fast``, which operates on the solver's ``t``
    values (``dpm_solver.t(sigma)``).

    Args:
        model: Model handed to ``DPMSolver``.
        x: Initial sample, passed unchanged to ``dpm_solver_fast``.
        sigma_min: Final noise level; must be strictly positive. May be a
            Python number or a 0-d tensor.
        sigma_max: Starting noise level; must be strictly positive. May be a
            Python number or a 0-d tensor.
        n: Budget forwarded to ``dpm_solver_fast`` (presumably the number of
            model evaluations -- verify against ``DPMSolver``); also used as
            the progress-bar total.
        extra_args: Forwarded to ``DPMSolver``.
        callback: Optional callable. It receives the solver's ``info`` dict,
            extended with ``'sigma'`` (computed from ``info['t']``) and
            ``'sigma_hat'`` (computed from ``info['t_up']``).
        disable: Forwarded to ``tqdm`` as its ``disable`` argument.
        eta, s_noise, noise_sampler: Forwarded to ``dpm_solver_fast``.

    Returns:
        Whatever ``DPMSolver.dpm_solver_fast`` returns (presumably the final
        sample).

    Raises:
        ValueError: If ``sigma_min`` or ``sigma_max`` is not strictly
            positive (this includes zero, negative values and NaN).
    """
    # Written as "not (> 0)" rather than "<= 0" so NaN is rejected too
    # (nan <= 0 is False). The sigmas are mapped through dpm_solver.t(),
    # which presumably is log-based and undefined for sigma <= 0 -- verify.
    if not (sigma_min > 0 and sigma_max > 0):
        raise ValueError('sigma_min and sigma_max must be positive')
    with tqdm(total=n, disable=disable) as pbar:
        # The solver invokes pbar.update through eps_callback, so the bar
        # advances as the solver calls back.
        dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
        if callback is not None:
            dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info})
        # as_tensor accepts both Python numbers and tensors. torch.tensor would
        # copy a tensor argument and emit a copy-construct UserWarning.
        t_start = dpm_solver.t(torch.as_tensor(sigma_max))
        t_end = dpm_solver.t(torch.as_tensor(sigma_min))
        return dpm_solver.dpm_solver_fast(x, t_start, t_end, n, eta, s_noise, noise_sampler)
| 496 | |
| 497 | |
| 498 | @torch.no_grad() |
nothing calls this directly
no test coverage detected