A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022).
sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.)
| 191 | |
@torch.no_grad()
def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022).

    Walks the noise schedule ``sigmas`` pairwise. Each step optionally adds
    fresh noise ("churn"), evaluates the model once at the (possibly raised)
    noise level, and then advances ``x`` either with a single Euler update
    (when the next sigma is exactly 0) or with a two-evaluation midpoint
    update whose midpoint is taken halfway between the two sigmas in log
    space.

    Args:
        model: Callable invoked as ``model(x, sigma * ones(batch), **extra_args)``;
            its result is treated as the denoised prediction.
        x: Starting sample tensor; ``x.shape[0]`` is used as the batch size.
        sigmas: Indexable schedule of noise levels. Elements must support
            tensor methods such as ``.log()`` (presumably a 1-D tensor —
            confirm against callers).
        extra_args: Optional dict of extra keyword arguments forwarded to
            ``model``. ``None`` is treated as an empty dict.
        callback: Optional callable receiving a dict with keys ``'x'``,
            ``'i'``, ``'sigma'``, ``'sigma_hat'`` and ``'denoised'`` once per
            step, after the first model evaluation and before ``x`` is
            advanced.
        disable: Forwarded to ``trange`` as its ``disable`` argument.
        s_churn: Churn amount; the per-step ``gamma`` is
            ``min(s_churn / (len(sigmas) - 1), sqrt(2) - 1)``.
        s_tmin: Lower bound of the sigma range in which churn is applied.
        s_tmax: Upper bound of the sigma range in which churn is applied.
        s_noise: Scale applied to the freshly drawn noise.

    Returns:
        The sample tensor after the last step of the schedule.
    """
    if extra_args is None:
        extra_args = {}
    # One entry per batch element, used to broadcast a scalar sigma to the batch.
    batch_ones = x.new_ones([x.shape[0]])
    n_steps = len(sigmas) - 1

    for step in trange(n_steps, disable=disable):
        sigma_cur = sigmas[step]
        sigma_next = sigmas[step + 1]

        # Churn is only active while the current sigma lies in [s_tmin, s_tmax].
        if s_tmin <= sigma_cur <= s_tmax:
            gamma = min(s_churn / n_steps, 2 ** 0.5 - 1)
        else:
            gamma = 0.

        # The noise is drawn on every step, even when gamma == 0 and it goes
        # unused, so the random-number stream advances identically per step.
        noise = torch.randn_like(x) * s_noise
        sigma_hat = sigma_cur * (gamma + 1)
        if gamma > 0:
            # Add exactly enough noise to move from sigma_cur up to sigma_hat.
            x = x + noise * (sigma_hat ** 2 - sigma_cur ** 2) ** 0.5

        denoised = model(x, sigma_hat * batch_ones, **extra_args)
        # to_d is defined elsewhere in this file; its result is used as the
        # update direction that gets scaled by the sigma step below.
        direction = to_d(x, sigma_hat, denoised)

        if callback is not None:
            callback({'x': x, 'i': step, 'sigma': sigma_cur, 'sigma_hat': sigma_hat, 'denoised': denoised})

        if sigma_next == 0:
            # Euler method: log(0) is undefined, so no midpoint can be formed.
            step_size = sigma_next - sigma_hat
            x = x + direction * step_size
            continue

        # DPM-Solver-2: evaluate at the log-space midpoint, then take the full
        # step from x using the midpoint direction.
        log_hat = sigma_hat.log()
        log_next = sigma_next.log()
        sigma_mid = log_hat.lerp(log_next, 0.5).exp()
        half_step = sigma_mid - sigma_hat
        full_step = sigma_next - sigma_hat
        x_mid = x + direction * half_step
        denoised_mid = model(x_mid, sigma_mid * batch_ones, **extra_args)
        direction_mid = to_d(x_mid, sigma_mid, denoised_mid)
        x = x + direction_mid * full_step

    return x
| 221 | |
| 222 | |
| 223 | @torch.no_grad() |