DPM-Solver-12 and 23 (adaptive step size). See https://arxiv.org/abs/2206.00927.
(model, x, sigma_min, sigma_max, extra_args=None, callback=None, disable=None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None, return_info=False)
| 497 | |
def _sigma_to_tensor(sigma):
    """Return ``sigma`` as a new tensor, whether it arrives as a number or a tensor."""
    # torch.tensor(existing_tensor) emits a UserWarning recommending
    # clone().detach(); do exactly that so tensor-valued sigmas stay quiet
    # while still handing the solver a fresh copy, as before.
    if isinstance(sigma, torch.Tensor):
        return sigma.clone().detach()
    return torch.tensor(sigma)


@torch.no_grad()
def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callback=None, disable=None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None, return_info=False):
    """DPM-Solver-12 and 23 (adaptive step size). See https://arxiv.org/abs/2206.00927.

    Integrates from ``sigma_max`` to ``sigma_min`` by delegating to
    ``DPMSolver.dpm_solver_adaptive``.

    Args:
        model: Passed to ``DPMSolver`` together with ``extra_args``.
        x: Initial state handed to the solver.
        sigma_min: Final noise level; must be strictly positive.
        sigma_max: Initial noise level; must be strictly positive.
        extra_args: Passed to ``DPMSolver`` alongside ``model``.
        callback: Optional callable. When given, it receives the solver's
            info dict extended with ``'sigma'`` and ``'sigma_hat'``, obtained
            by converting the info dict's ``'t'`` and ``'t_up'`` entries with
            ``dpm_solver.sigma``.
        disable: Forwarded to ``tqdm(disable=...)``.
        order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta,
        s_noise, noise_sampler: Forwarded unchanged, in this order, to
            ``DPMSolver.dpm_solver_adaptive``.
        return_info: If true, return ``(x, info)`` instead of only ``x``.

    Returns:
        The solver's resulting ``x``, or ``(x, info)`` when ``return_info``
        is true.

    Raises:
        ValueError: If ``sigma_min`` or ``sigma_max`` is zero or negative.
    """
    if sigma_min <= 0 or sigma_max <= 0:
        # The guard rejects negatives as well as zero, so say so.
        raise ValueError('sigma_min and sigma_max must be positive')
    with tqdm(disable=disable) as pbar:
        # The progress bar's update method is handed to the solver as its eps_callback.
        dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
        if callback is not None:
            dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info})
        # The solver takes its endpoints as t values, so convert both sigmas via dpm_solver.t.
        t_start = dpm_solver.t(_sigma_to_tensor(sigma_max))
        t_end = dpm_solver.t(_sigma_to_tensor(sigma_min))
        x, info = dpm_solver.dpm_solver_adaptive(x, t_start, t_end, order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise, noise_sampler)
    if return_info:
        return x, info
    return x
| 511 | |
| 512 | |
| 513 | @torch.no_grad() |
nothing calls this directly
no test coverage detected