Repository: github.com/XPixelGroup/DiffBIR

Function sample_dpm_2

diffbir/sampler/k_diffusion.py:193–220

A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022).
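Written out for one iteration i, the loop body computes the quantities below. Here D stands for `model`, gamma_i for `gamma`, and epsilon for `eps` (standard normal noise scaled by `s_noise`). The derivative lines assume that `to_d(x, sigma, denoised)` returns (x - denoised) / sigma, as in upstream k-diffusion; this page does not show its body.

```latex
\begin{aligned}
\hat\sigma_i &= \sigma_i (1 + \gamma_i), \qquad
\hat x_i = x_i + \epsilon \sqrt{\hat\sigma_i^2 - \sigma_i^2} \\
d_i &= \frac{\hat x_i - D(\hat x_i; \hat\sigma_i)}{\hat\sigma_i} \\
\sigma_{\mathrm{mid}} &= \exp\!\left(\tfrac{1}{2}\left(\ln\hat\sigma_i + \ln\sigma_{i+1}\right)\right) = \sqrt{\hat\sigma_i\,\sigma_{i+1}} \\
x' &= \hat x_i + d_i\,(\sigma_{\mathrm{mid}} - \hat\sigma_i), \qquad
d' = \frac{x' - D(x'; \sigma_{\mathrm{mid}})}{\sigma_{\mathrm{mid}}} \\
x_{i+1} &= \hat x_i + d'\,(\sigma_{i+1} - \hat\sigma_i) \\
x_{i+1} &= \hat x_i + d_i\,(0 - \hat\sigma_i) = D(\hat x_i; \hat\sigma_i) \quad \text{(Euler branch, used when } \sigma_{i+1} = 0\text{)}
\end{aligned}
```

The midpoint is taken in log-sigma space, so it is the geometric mean of the two noise levels, while both updates step from the same starting point x̂_i. Under the `to_d` assumption, the final Euler step simply returns the denoiser's estimate.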

sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.)
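The churn parameters in the signature (`s_churn`, `s_tmin`, `s_tmax`, `s_noise`) control how much fresh noise each step injects before the model is called. The scalar sketch below isolates that logic from the loop so it can be inspected on its own. The helper names `churn_gamma` and `churn_noise_std` are hypothetical, and `n_steps` stands for `len(sigmas) - 1`.

```python
import math


def churn_gamma(sigma: float, n_steps: int, s_churn: float = 0.0,
                s_tmin: float = 0.0, s_tmax: float = float("inf")) -> float:
    """Churn factor for one step (hypothetical helper mirroring the `gamma = ...` line)."""
    if s_tmin <= sigma <= s_tmax:
        # Spread s_churn over all steps, but cap it so sigma_hat <= sqrt(2) * sigma.
        return min(s_churn / n_steps, 2 ** 0.5 - 1)
    # Outside [s_tmin, s_tmax] no extra noise is injected.
    return 0.0


def churn_noise_std(sigma: float, gamma: float, s_noise: float = 1.0) -> float:
    """Standard deviation of the noise added to x before the model call."""
    sigma_hat = sigma * (gamma + 1)
    # In sample_dpm_2 the s_noise factor lives in eps; here it is folded into the std.
    return s_noise * math.sqrt(sigma_hat ** 2 - sigma ** 2)


if __name__ == "__main__":
    g = churn_gamma(sigma=1.0, n_steps=10, s_churn=100.0)  # hits the sqrt(2) - 1 cap
    print(g, churn_noise_std(1.0, g))                       # noise std is about sigma itself
```

Because gamma is capped at sqrt(2) - 1, sigma_hat never exceeds sqrt(2) times sigma_i, so the injected noise has a standard deviation of at most s_noise times sigma_i. With the default `s_churn=0.`, gamma is 0 on every step, the `if gamma > 0` branch never runs, and the sampler is deterministic for a deterministic model (the `eps` draw still happens but is unused).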


```python
@torch.no_grad()
def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Euler method
            dt = sigmas[i + 1] - sigma_hat
            x = x + d * dt
        else:
            # DPM-Solver-2
            sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
            dt_1 = sigma_mid - sigma_hat
            dt_2 = sigmas[i + 1] - sigma_hat
            x_2 = x + d * dt_1
            denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
            d_2 = to_d(x_2, sigma_mid, denoised_2)
            x = x + d_2 * dt_2
    return x
```
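To see the update rule in isolation, here is a minimal scalar sketch of the deterministic path (`s_churn=0`) in plain Python, without torch, `extra_args`, or `callback`. It is not the DiffBIR code: `sample_dpm_2_scalar`, `gaussian_denoiser`, and `exact_solution` are hypothetical names, and `to_d` is re-implemented under the assumption that it matches upstream k-diffusion's `(x - denoised) / sigma`. A Gaussian toy data distribution is used because its optimal denoiser has a closed-form ODE solution, so the sampler's output can be checked against the exact answer.

```python
import math
from typing import Callable, List


def to_d(x: float, sigma: float, denoised: float) -> float:
    # Assumed to match upstream k-diffusion: the Karras ODE derivative dx/dsigma.
    return (x - denoised) / sigma


def sample_dpm_2_scalar(model: Callable[[float, float], float], x: float,
                        sigmas: List[float]) -> float:
    """Deterministic (s_churn=0) scalar version of the sample_dpm_2 loop."""
    for i in range(len(sigmas) - 1):
        sigma, sigma_next = sigmas[i], sigmas[i + 1]
        d = to_d(x, sigma, model(x, sigma))
        if sigma_next == 0:
            # Final step: plain Euler, since the log-space midpoint is undefined at 0.
            x = x + d * (sigma_next - sigma)
        else:
            # Midpoint in log-sigma space, i.e. the geometric mean of the two sigmas.
            sigma_mid = math.exp(0.5 * (math.log(sigma) + math.log(sigma_next)))
            x_2 = x + d * (sigma_mid - sigma)
            d_2 = to_d(x_2, sigma_mid, model(x_2, sigma_mid))
            # Full step from the original x, using the slope evaluated at the midpoint.
            x = x + d_2 * (sigma_next - sigma)
    return x


# Toy data distribution N(0, SIGMA_DATA^2): its optimal denoiser is linear in x.
SIGMA_DATA = 1.0


def gaussian_denoiser(x: float, sigma: float) -> float:
    return x * SIGMA_DATA ** 2 / (SIGMA_DATA ** 2 + sigma ** 2)


def exact_solution(x0: float, sigma0: float, sigma1: float) -> float:
    # For this denoiser dx/dsigma = x * sigma / (SIGMA_DATA^2 + sigma^2),
    # whose solution is x(sigma) proportional to sqrt(SIGMA_DATA^2 + sigma^2).
    return x0 * math.sqrt((SIGMA_DATA ** 2 + sigma1 ** 2) / (SIGMA_DATA ** 2 + sigma0 ** 2))


if __name__ == "__main__":
    # 41 log-spaced sigmas from 10 down to 0.1; no trailing 0, so every step is DPM-Solver-2.
    sigmas = [10.0 * 0.01 ** (k / 40) for k in range(41)]
    x0 = 3.0
    print(sample_dpm_2_scalar(gaussian_denoiser, x0, sigmas),
          exact_solution(x0, sigmas[0], sigmas[-1]))
```

With 40 log-spaced steps the two printed numbers should agree to well under one percent; the midpoint evaluation costs one extra model call per step and in exchange gives a second-order step. If `sigmas` ends in 0, the last iteration takes the Euler branch, which (under the `to_d` assumption) returns exactly the denoiser output at the previous sigma.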

Callers

nothing calls this directly
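No caller is indexed, but the loop implies what `sigmas` must look like: a decreasing sequence, usually ending in 0 so that the final iteration takes the Euler branch (the DPM-Solver-2 branch takes the log of the next sigma, so 0 has to be special-cased). Because the loop calls `.log()` on elements derived from `sigmas`, a real caller would pass a 1-D tensor rather than a Python list. One common way to build such a schedule is the rho-spaced schedule from Karras et al. (2022). The sketch below assumes that schedule and the paper's default rho = 7; `karras_sigmas` is a hypothetical helper, not a function shown on this page.

```python
from typing import List


def karras_sigmas(n: int, sigma_min: float, sigma_max: float, rho: float = 7.0) -> List[float]:
    """Return n decreasing sigmas from sigma_max to sigma_min, plus a trailing 0.0.

    Interpolates linearly in sigma ** (1 / rho) space, which concentrates steps
    at small sigmas when rho > 1. Requires n >= 2.
    """
    if n < 2:
        raise ValueError("need at least two nonzero sigmas")
    max_inv_rho = sigma_max ** (1 / rho)
    min_inv_rho = sigma_min ** (1 / rho)
    sigmas = [
        (max_inv_rho + k / (n - 1) * (min_inv_rho - max_inv_rho)) ** rho
        for k in range(n)
    ]
    # The trailing 0 makes the last iteration of sample_dpm_2 take its Euler branch.
    return sigmas + [0.0]
```

For example, `karras_sigmas(10, 0.1, 10.0)` yields 11 values starting at 10.0 and ending in 0.0. Passed to `sample_dpm_2`, that gives ten loop iterations: nine DPM-Solver-2 steps with two model calls each and one Euler step with a single call, 19 model evaluations in total.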

Calls (1)

to_d (function) · 0.85

Tested by

no test coverage detected