MCPcopy
hub / github.com/XPixelGroup/DiffBIR / dpm_solver_fast

Method dpm_solver_fast

diffbir/sampler/k_diffusion.py:393–430  ·  view source on GitHub ↗
(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_sampler=None)

Source from the content-addressed store, hash-verified

391 return x_3, eps_cache
392
def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_sampler=None):
    """Run DPM-Solver-Fast from ``t_start`` to ``t_end`` with a budget of ``nfe``.

    The range is split into ``floor(nfe / 3) + 1`` equal intervals of the
    solver's ``t`` variable. Each interval is integrated with one
    ``dpm_solver_{1,2,3}_step`` call. Every interval uses order 3 except the
    tail, whose orders are chosen so that all step orders sum to exactly
    ``nfe``:

    * ``nfe % 3 == 0``: orders ``[3, ..., 3, 2, 1]``
    * otherwise:        orders ``[3, ..., 3, nfe % 3]``

    Args:
        x: Sample at ``t_start``.
        t_start: Starting value of ``t``.
        t_end: Final value of ``t``. When ``eta`` is non-zero it is passed to
            ``torch.minimum``, so it must be a tensor in that case.
        nfe: Evaluation budget; the step orders sum to this value. Must be
            at least 1.
        eta: Ancestral-noise amount; ``0`` gives a deterministic sampler.
        s_noise: Multiplier applied to the injected noise.
        noise_sampler: Callable ``(sigma, sigma_next) -> noise``. Defaults to
            ``default_noise_sampler(x)``.

    Returns:
        The sample after the last interval (ending at ``t_end``).

    Raises:
        ValueError: If ``nfe < 1``, or if ``eta`` is non-zero while ``t_end``
            is not greater than ``t_start`` (reverse sampling).
    """
    if nfe < 1:
        # Without this check, nfe <= 0 produces an `orders` list that does not
        # fit the `ts` grid (or an invalid linspace size). The run then dies
        # with an opaque IndexError/RuntimeError; for nfe == 0 this happens
        # only after a full model step has already been spent.
        raise ValueError('nfe must be at least 1')
    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
    if not t_end > t_start and eta:
        raise ValueError('eta must be 0 for reverse sampling')

    # m intervals (m + 1 grid points); the orders below have length m.
    m = math.floor(nfe / 3) + 1
    ts = torch.linspace(t_start, t_end, m + 1, device=x.device)

    if nfe % 3 == 0:
        orders = [3] * (m - 2) + [2, 1]
    else:
        orders = [3] * (m - 1) + [nfe % 3]

    for i, order in enumerate(orders):
        # Fresh cache per interval. The 'eps' evaluation at (x, t) below is
        # handed to the step method, presumably so it is not recomputed there.
        eps_cache = {}
        t, t_next = ts[i], ts[i + 1]
        if eta:
            # NOTE(review): sd/su are presumably the (sigma_down, sigma_up)
            # pair from get_ancestral_step. The deterministic step targets
            # t_next_, clamped so that it never passes t_end.
            sd, su = get_ancestral_step(self.sigma(t), self.sigma(t_next), eta)
            t_next_ = torch.minimum(t_end, self.t(sd))
            # Recompute the noise std from the clamped target so the added noise
            # bridges sigma(t_next_) up to sigma(t_next).
            su = (self.sigma(t_next) ** 2 - self.sigma(t_next_) ** 2) ** 0.5
        else:
            t_next_, su = t_next, 0.

        eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
        denoised = x - self.sigma(t) * eps
        if self.info_callback is not None:
            self.info_callback({'x': x, 'i': i, 't': ts[i], 't_up': t, 'denoised': denoised})

        if order == 1:
            x, eps_cache = self.dpm_solver_1_step(x, t, t_next_, eps_cache=eps_cache)
        elif order == 2:
            x, eps_cache = self.dpm_solver_2_step(x, t, t_next_, eps_cache=eps_cache)
        else:
            x, eps_cache = self.dpm_solver_3_step(x, t, t_next_, eps_cache=eps_cache)

        # The noise sampler is called on every step, even when su == 0, so the
        # number of noise draws does not depend on eta.
        x = x + su * s_noise * noise_sampler(self.sigma(t), self.sigma(t_next))

    return x
431
432 def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None):
433 noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler

Callers 1

sample_dpm_fast · Function · 0.95

Calls 8

sigma · Method · 0.95
t · Method · 0.95
eps · Method · 0.95
dpm_solver_1_step · Method · 0.95
dpm_solver_2_step · Method · 0.95
dpm_solver_3_step · Method · 0.95
default_noise_sampler · Function · 0.85
get_ancestral_step · Function · 0.85

Tested by

no test coverage detected