MCPcopy
hub / github.com/XPixelGroup/DiffBIR / sample_dpmpp_sde

Function sample_dpmpp_sde

diffbir/sampler/k_diffusion.py:548–586  ·  view source on GitHub ↗

DPM-Solver++ (stochastic).

(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2)

Source from the content-addressed store, hash-verified

546
@torch.no_grad()
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
    """DPM-Solver++ (stochastic).

    Steps ``x`` through the noise levels in ``sigmas``, one step per
    consecutive pair ``(sigmas[i], sigmas[i + 1])``.

    Each step works as follows:
      * If the target sigma is exactly 0, it is a single Euler step with one
        model evaluation and no injected noise.
      * Otherwise it is a two-stage step with two model evaluations. The
        first stage moves to an intermediate time ``s``; the second makes a
        full step using a blend of both denoised estimates. Both stages add
        fresh noise from ``noise_sampler``.

    Args:
        model: Callable invoked as ``model(x, sigma_batch, **extra_args)``.
            Its output is used as the denoised prediction for ``x``.
        x: Starting sample tensor. ``x.shape[0]`` is used as the batch size
            when broadcasting the scalar sigma passed to ``model``.
        sigmas: 1-D tensor of noise levels to step through. It is presumably
            decreasing and ending in 0 (TODO: confirm against callers).
        extra_args: Extra keyword arguments forwarded to every ``model``
            call. ``None`` means no extra arguments.
        callback: Optional callable. It is invoked once per step, before the
            update, with a dict containing ``x``, ``i``, ``sigma``,
            ``sigma_hat`` and ``denoised``.
        disable: Forwarded as ``disable=`` to ``trange``.
        eta: Forwarded to ``get_ancestral_step``. It affects how each move
            is split into a deterministic target sigma and a noise scale.
        s_noise: Extra multiplier applied to every injected noise sample.
        noise_sampler: Callable ``noise_sampler(sigma, sigma_next)`` that
            returns noise to add to the sample. Defaults to a
            ``BrownianTreeNoiseSampler`` spanning the smallest positive sigma
            and the largest sigma.
        r: Position of the intermediate time, as a fraction of each step in
            ``t = -log(sigma)``. It also sets the blend weight
            ``1 / (2 * r)``, so it must be non-zero.

    Returns:
        The sample after the final step.
    """
    # Only positive sigmas bound the default Brownian tree; a trailing 0
    # (t = +inf) is excluded.
    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
    extra_args = {} if extra_args is None else extra_args
    # Multiplying a scalar sigma by this gives one sigma per batch element.
    s_in = x.new_ones([x.shape[0]])

    def sigma_fn(t):
        # Inverse of t_fn: sigma = exp(-t).
        return t.neg().exp()

    def t_fn(sigma):
        # Solver time variable: t = -log(sigma).
        return sigma.log().neg()

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Euler method. t_fn(0) = -log(0) is infinite, so the log-space
            # update below cannot be used for a final step to sigma = 0.
            d = to_d(x, sigmas[i], denoised)
            dt = sigmas[i + 1] - sigmas[i]
            x = x + d * dt
        else:
            # DPM-Solver++: two-stage step in t = -log(sigma).
            t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
            h = t_next - t
            s = t + h * r  # intermediate time, fraction r of the way to t_next
            fac = 1 / (2 * r)  # weight of the second denoised estimate

            # Step 1: move from t to the intermediate time s.
            # get_ancestral_step yields (sd, su): sd is the deterministic
            # target sigma, and su scales the injected noise.
            sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
            s_ = t_fn(sd)
            # With ratio = sigma(s_) / sigma(t) = exp(t - s_), this is
            # ratio * x + (1 - ratio) * denoised. Using expm1 keeps (1 - ratio)
            # accurate when the step is small.
            x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised
            x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
            denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)

            # Step 2: full move from t to t_next, using the blended estimate
            # (1 - fac) * denoised + fac * denoised_2.
            sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta)
            t_next_ = t_fn(sd)
            denoised_d = (1 - fac) * denoised + fac * denoised_2
            x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d
            x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
    return x
587
588
589@torch.no_grad()

Callers

nothing calls this directly

Calls 3

to_d · Function · 0.85
get_ancestral_step · Function · 0.85

Tested by

no test coverage detected