(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None)
| 430 | return x |
| 431 | |
def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None):
    """Integrate from ``t_start`` to ``t_end`` with an adaptive-step DPM-Solver.

    Each step attempt runs a lower-order and a higher-order solver step
    (1st/2nd order when ``order == 2``, 2nd/3rd order when ``order == 3``).
    The RMS of their difference, scaled by mixed absolute/relative
    tolerances, is the local error estimate. A PID step-size controller
    uses that estimate to accept or reject the step and to choose the
    next step size.

    Args:
        x: Current sample.
        t_start: Start time. Both time arguments are passed through
            ``torch.minimum``/``torch.maximum``, so they are expected to be
            tensors.
        t_end: End time. The solver runs forward when ``t_end > t_start``
            and in reverse otherwise.
        order: Order of the higher-order solver; must be 2 or 3.
        rtol: Relative tolerance used in the error estimate.
        atol: Absolute tolerance used in the error estimate.
        h_init: Initial step-size magnitude. Its sign is set from the
            integration direction.
        pcoeff, icoeff, dcoeff: Coefficients forwarded to
            ``PIDStepSizeController``.
        accept_safety: Forwarded to ``PIDStepSizeController``.
        eta: Ancestral noise amount (0 disables noise injection). Must be 0
            for reverse sampling.
        s_noise: Multiplier applied to the injected noise.
        noise_sampler: Callable ``(sigma, sigma_next) -> noise``. Defaults
            to ``default_noise_sampler(x)``.

    Returns:
        Tuple ``(x, info)``. ``info`` holds the counters ``steps`` (step
        attempts), ``nfe`` (model evaluations), ``n_accept`` and
        ``n_reject``.

    Raises:
        ValueError: If ``order`` is not 2 or 3, or if ``eta`` is non-zero
            when sampling in reverse.
    """
    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
    if order not in {2, 3}:
        raise ValueError('order should be 2 or 3')
    forward = t_end > t_start
    if not forward and eta:
        raise ValueError('eta must be 0 for reverse sampling')
    h_init = abs(h_init) * (1 if forward else -1)
    atol = torch.tensor(atol)
    rtol = torch.tensor(rtol)
    s = t_start
    x_prev = x
    # With ancestral noise the controller is given 1.5 instead of `order`
    # (presumably the lower effective order of the stochastic step).
    pid = PIDStepSizeController(h_init, pcoeff, icoeff, dcoeff, 1.5 if eta else order, accept_safety)
    info = {'steps': 0, 'nfe': 0, 'n_accept': 0, 'n_reject': 0}
    # Model evaluations keyed by name. Only the 'eps' entry (evaluated at
    # (x, s)) survives a rejected step, because x and s do not change then.
    # Entries at intermediate points depend on the rejected step size.
    eps_cache = {}

    while (s < t_end - 1e-5) if forward else (s > t_end + 1e-5):
        t = torch.minimum(t_end, s + pid.h) if forward else torch.maximum(t_end, s + pid.h)
        if eta:
            # Deterministically step down to t_, then add noise of std `su`
            # to land on sigma(t).
            sd, su = get_ancestral_step(self.sigma(s), self.sigma(t), eta)
            t_ = torch.minimum(t_end, self.t(sd))
            su = (self.sigma(t) ** 2 - self.sigma(t_) ** 2) ** 0.5
        else:
            t_, su = t, 0.

        # When 'eps' is already cached (after a rejection), self.eps is
        # expected to return it instead of re-running the model. This is the
        # same mechanism the step helpers below rely on to share it.
        reused_eps = 'eps' in eps_cache
        eps, eps_cache = self.eps(eps_cache, 'eps', x, s)
        denoised = x - self.sigma(s) * eps

        if order == 2:
            x_low, eps_cache = self.dpm_solver_1_step(x, s, t_, eps_cache=eps_cache)
            x_high, eps_cache = self.dpm_solver_2_step(x, s, t_, eps_cache=eps_cache)
        else:
            x_low, eps_cache = self.dpm_solver_2_step(x, s, t_, r1=1 / 3, eps_cache=eps_cache)
            x_high, eps_cache = self.dpm_solver_3_step(x, s, t_, eps_cache=eps_cache)
        # Mixed absolute/relative tolerance, then an RMS-normalized error.
        delta = torch.maximum(atol, rtol * torch.maximum(x_low.abs(), x_prev.abs()))
        error = torch.linalg.norm((x_low - x_high) / delta) / x.numel() ** 0.5
        accept = pid.propose_step(error)
        if accept:
            x_prev = x_low
            x = x_high
            # Only draw noise when it is actually used. With eta == 0, `su`
            # is 0, and sampling would just waste an allocation and advance
            # the RNG.
            if eta:
                x = x + su * s_noise * noise_sampler(self.sigma(s), self.sigma(t))
            s = t
            eps_cache = {}
            info['n_accept'] += 1
        else:
            eps_cache = {'eps': eps}
            info['n_reject'] += 1
        info['nfe'] += order - 1 if reused_eps else order
        info['steps'] += 1

        if self.info_callback is not None:
            self.info_callback({'x': x, 'i': info['steps'] - 1, 't': s, 't_up': s, 'denoised': denoised, 'error': error, 'h': pid.h, **info})

    return x, info
| 484 | |
| 485 | |
| 486 | @torch.no_grad() |
no test coverage detected