(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_sampler=None)
| 391 | return x_3, eps_cache |
| 392 | |
def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_sampler=None):
    """Integrate from ``t_start`` to ``t_end`` with a fixed DPM-Solver step schedule.

    The interval is split into ``floor(nfe / 3) + 1`` uniform steps in ``t``.
    All steps except the tail are third order; the tail uses lower orders so
    that ``sum(orders) == nfe`` (e.g. ``nfe=7`` -> ``[3, 3, 1]``,
    ``nfe=6`` -> ``[3, 2, 1]``, ``nfe=5`` -> ``[3, 2]``). Presumably an
    order-k step costs k model evaluations, which makes ``nfe`` the
    evaluation budget -- TODO confirm against ``dpm_solver_{1,2,3}_step``.

    Args:
        x: Starting sample tensor; its device is used for the time grid.
        t_start: Time value the integration starts from.
        t_end: Time value the integration ends at. ``t_end > t_start`` is the
            forward direction; any other ordering is treated as reverse
            sampling, which requires ``eta == 0``. May be a float or a
            0-dim tensor.
        nfe: Total of the per-step orders; must be at least 1.
        eta: Ancestral noise amount. When non-zero, each step targets the
            time ``self.t(sd)`` from ``get_ancestral_step`` (clamped to at
            most ``t_end``) and noise scaled by
            ``sqrt(sigma(t_next)**2 - sigma(t_next_)**2)`` is added.
        s_noise: Extra multiplier applied to the injected noise.
        noise_sampler: Callable invoked as ``noise_sampler(sigma, sigma_next)``;
            defaults to ``default_noise_sampler(x)``.

    Returns:
        The sample after the final step.

    Raises:
        ValueError: If ``eta`` is non-zero while ``t_end`` is not greater
            than ``t_start``, or if ``nfe < 1``.
    """
    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
    if not t_end > t_start and eta:
        raise ValueError('eta must be 0 for reverse sampling')
    # With nfe < 1 the order schedule below is longer than the time grid
    # (nfe=0 yields orders [2, 1] over a single interval), which would
    # otherwise fail later with an opaque IndexError.
    if nfe < 1:
        raise ValueError('nfe must be at least 1')

    # math.floor (rather than //) keeps m an int even for an int-valued float
    # nfe, which linspace and list repetition below require.
    m = math.floor(nfe / 3) + 1
    ts = torch.linspace(t_start, t_end, m + 1, device=x.device)

    # Spend the budget on third-order steps, then finish with the remainder;
    # when nfe is a multiple of 3 the tail is a 2 + 1 pair instead.
    if nfe % 3 == 0:
        orders = [3] * (m - 2) + [2, 1]
    else:
        orders = [3] * (m - 1) + [nfe % 3]

    # torch.minimum needs tensor operands; as_tensor returns tensors unchanged
    # and lets a plain-float t_end work when eta > 0.
    t_end_tensor = torch.as_tensor(t_end)

    for i, order in enumerate(orders):
        # Fresh cache per step; presumably the step functions reuse the 'eps'
        # entry computed below -- TODO confirm.
        eps_cache = {}
        t, t_next = ts[i], ts[i + 1]
        if eta:
            # Step deterministically to t_next_ (clamped to at most t_end), then
            # re-inject noise of scale su to get back up to sigma(t_next).
            sd, su = get_ancestral_step(self.sigma(t), self.sigma(t_next), eta)
            t_next_ = torch.minimum(t_end_tensor, self.t(sd))
            su = (self.sigma(t_next) ** 2 - self.sigma(t_next_) ** 2) ** 0.5
        else:
            t_next_, su = t_next, 0.

        eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
        denoised = x - self.sigma(t) * eps
        if self.info_callback is not None:
            self.info_callback({'x': x, 'i': i, 't': ts[i], 't_up': t, 'denoised': denoised})

        if order == 1:
            x, eps_cache = self.dpm_solver_1_step(x, t, t_next_, eps_cache=eps_cache)
        elif order == 2:
            x, eps_cache = self.dpm_solver_2_step(x, t, t_next_, eps_cache=eps_cache)
        else:
            x, eps_cache = self.dpm_solver_3_step(x, t, t_next_, eps_cache=eps_cache)

        # su is 0. when eta == 0, so this adds nothing; noise_sampler is still
        # called every step, which keeps any RNG it advances in the same state.
        x = x + su * s_noise * noise_sampler(self.sigma(t), self.sigma(t_next))

    return x
| 431 | |
| 432 | def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None): |
| 433 | noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler |
no test coverage detected