MCPcopy
hub / github.com/XPixelGroup/DiffBIR / DPMSolver

Class DPMSolver

diffbir/sampler/k_diffusion.py:338–483  ·  view source on GitHub ↗

DPM-Solver. See https://arxiv.org/abs/2206.00927.

Source from the content-addressed store, hash-verified

336
337
338class DPMSolver(nn.Module):
339 """DPM-Solver. See https://arxiv.org/abs/2206.00927."""
340
341 def __init__(self, model, extra_args=None, eps_callback=None, info_callback=None):
342 super().__init__()
343 self.model = model
344 self.extra_args = {} if extra_args is None else extra_args
345 self.eps_callback = eps_callback
346 self.info_callback = info_callback
347
348 def t(self, sigma):
349 return -sigma.log()
350
351 def sigma(self, t):
352 return t.neg().exp()
353
354 def eps(self, eps_cache, key, x, t, *args, **kwargs):
355 if key in eps_cache:
356 return eps_cache[key], eps_cache
357 sigma = self.sigma(t) * x.new_ones([x.shape[0]])
358 eps = (x - self.model(x, sigma, *args, **self.extra_args, **kwargs)) / self.sigma(t)
359 if self.eps_callback is not None:
360 self.eps_callback()
361 return eps, {key: eps, **eps_cache}
362
363 def dpm_solver_1_step(self, x, t, t_next, eps_cache=None):
364 eps_cache = {} if eps_cache is None else eps_cache
365 h = t_next - t
366 eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
367 x_1 = x - self.sigma(t_next) * h.expm1() * eps
368 return x_1, eps_cache
369
370 def dpm_solver_2_step(self, x, t, t_next, r1=1 / 2, eps_cache=None):
371 eps_cache = {} if eps_cache is None else eps_cache
372 h = t_next - t
373 eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
374 s1 = t + r1 * h
375 u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps
376 eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1)
377 x_2 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / (2 * r1) * h.expm1() * (eps_r1 - eps)
378 return x_2, eps_cache
379
380 def dpm_solver_3_step(self, x, t, t_next, r1=1 / 3, r2=2 / 3, eps_cache=None):
381 eps_cache = {} if eps_cache is None else eps_cache
382 h = t_next - t
383 eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
384 s1 = t + r1 * h
385 s2 = t + r2 * h
386 u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps
387 eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1)
388 u2 = x - self.sigma(s2) * (r2 * h).expm1() * eps - self.sigma(s2) * (r2 / r1) * ((r2 * h).expm1() / (r2 * h) - 1) * (eps_r1 - eps)
389 eps_r2, eps_cache = self.eps(eps_cache, 'eps_r2', u2, s2)
390 x_3 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / r2 * (h.expm1() / h - 1) * (eps_r2 - eps)
391 return x_3, eps_cache
392
393 def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_sampler=None):
394 noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
395 if not t_end > t_start and eta:

Callers 2

sample_dpm_fast (Function) · 0.85
sample_dpm_adaptive (Function) · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected