MCPcopy
hub / github.com/XPixelGroup/DiffBIR / dpm_solver_adaptive

Method dpm_solver_adaptive

diffbir/sampler/k_diffusion.py:432–483  ·  view source on GitHub ↗
(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None)

Source from the content-addressed store, hash-verified

430 return x
431
def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None):
    """Integrate from ``t_start`` to ``t_end`` with adaptive step-size DPM-Solver.

    Every iteration runs an embedded pair of solver steps from the current
    time towards a candidate time: ``dpm_solver_1_step`` / ``dpm_solver_2_step``
    when ``order == 2``, or ``dpm_solver_2_step`` (``r1=1/3``) /
    ``dpm_solver_3_step`` when ``order == 3``. The RMS of their scaled
    difference is the error estimate handed to a ``PIDStepSizeController``,
    which decides whether the step is accepted and adjusts the step size.
    Accepted steps keep the higher-order result.

    Args:
        x: State at ``t_start``.
        t_start, t_end: Integration bounds. ``t_end`` is passed to
            ``torch.minimum`` / ``torch.maximum``, so presumably a tensor —
            confirm with the caller. ``t_end > t_start`` means forward.
        order: 2 or 3.
        rtol, atol: Relative / absolute tolerances used to scale the error.
        h_init: Initial step magnitude; its sign is set from the direction.
        pcoeff, icoeff, dcoeff, accept_safety: Step-size controller settings.
        eta: Ancestral noise amount; must be 0 for reverse integration.
        s_noise: Multiplier on the injected noise.
        noise_sampler: Called as ``noise_sampler(sigma, sigma_next)``;
            defaults to ``default_noise_sampler(x)``.

    Returns:
        ``(x, info)`` where ``info`` holds the counters ``steps``, ``nfe``,
        ``n_accept`` and ``n_reject``.

    Raises:
        ValueError: If ``order`` is not 2 or 3, or if ``eta`` is nonzero
            while integrating in reverse.
    """
    if noise_sampler is None:
        noise_sampler = default_noise_sampler(x)
    if order not in {2, 3}:
        raise ValueError('order should be 2 or 3')
    going_forward = t_end > t_start
    if not going_forward and eta:
        raise ValueError('eta must be 0 for reverse sampling')
    direction = 1 if going_forward else -1
    h_init = abs(h_init) * direction
    atol, rtol = torch.tensor(atol), torch.tensor(rtol)

    # Embedded pair: the low-order result only feeds the error estimate,
    # the high-order result becomes the new state on acceptance.
    if order == 2:
        low_step, low_kwargs, high_step = self.dpm_solver_1_step, {}, self.dpm_solver_2_step
    else:
        low_step, low_kwargs, high_step = self.dpm_solver_2_step, {'r1': 1 / 3}, self.dpm_solver_3_step

    t_cur = t_start
    x_low_prev = x
    # 1.5 replaces ``order`` here when eta is set (presumably the controller's
    # order/exponent argument — confirm against PIDStepSizeController).
    controller = PIDStepSizeController(h_init, pcoeff, icoeff, dcoeff, 1.5 if eta else order, accept_safety)
    stats = {'steps': 0, 'nfe': 0, 'n_accept': 0, 'n_reject': 0}

    # Stop once within 1e-5 of t_end, in the direction of travel.
    while (t_cur < t_end - 1e-5) if going_forward else (t_cur > t_end + 1e-5):
        cache = {}  # model evaluations are shared only within one iteration
        if going_forward:
            t_next = torch.minimum(t_end, t_cur + controller.h)
        else:
            t_next = torch.maximum(t_end, t_cur + controller.h)

        if eta:
            sigma_down, _ = get_ancestral_step(self.sigma(t_cur), self.sigma(t_next), eta)
            t_target = torch.minimum(t_end, self.t(sigma_down))
            sigma_up = (self.sigma(t_next) ** 2 - self.sigma(t_target) ** 2) ** 0.5
        else:
            t_target = t_next
            sigma_up = 0.

        eps, cache = self.eps(cache, 'eps', x, t_cur)
        denoised = x - self.sigma(t_cur) * eps

        x_low, cache = low_step(x, t_cur, t_target, eps_cache=cache, **low_kwargs)
        x_high, cache = high_step(x, t_cur, t_target, eps_cache=cache)

        tolerance = torch.maximum(atol, rtol * torch.maximum(x_low.abs(), x_low_prev.abs()))
        error = torch.linalg.norm((x_low - x_high) / tolerance) / x.numel() ** 0.5

        if controller.propose_step(error):
            x_low_prev = x_low
            # noise_sampler is invoked on every accepted step, even when sigma_up is 0.
            x = x_high + sigma_up * s_noise * noise_sampler(self.sigma(t_cur), self.sigma(t_next))
            t_cur = t_next
            stats['n_accept'] += 1
        else:
            stats['n_reject'] += 1
        stats['nfe'] += order
        stats['steps'] += 1

        if self.info_callback is not None:
            self.info_callback({'x': x, 'i': stats['steps'] - 1, 't': t_cur, 't_up': t_cur, 'denoised': denoised, 'error': error, 'h': controller.h, **stats})

    return x, stats
484
485
486@torch.no_grad()

Callers 1

sample_dpm_adaptive · Function · 0.95

Calls 10

sigma · Method · 0.95
t · Method · 0.95
eps · Method · 0.95
dpm_solver_1_step · Method · 0.95
dpm_solver_2_step · Method · 0.95
dpm_solver_3_step · Method · 0.95
propose_step · Method · 0.95
default_noise_sampler · Function · 0.85
get_ancestral_step · Function · 0.85

Tested by

no test coverage detected