(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4)
| 264 | |
@torch.no_grad()
def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
    """Sample with a linear multistep (LMS) method using up to ``order`` past steps.

    Runs under ``torch.no_grad()``, so no autograd graph is recorded. For each
    step ``i`` in ``range(len(sigmas) - 1)``:

    1. The model is evaluated at noise level ``sigmas[i]``.
    2. The result is turned into a derivative with ``to_d``.
    3. ``x`` is advanced by a weighted sum of the most recent derivatives.

    The weights come from ``linear_multistep_coeff``, computed from a NumPy
    copy of the sigma schedule.

    Args:
        model: Callable invoked as ``model(x, sigma_batch, **extra_args)``.
            ``sigma_batch`` is ``sigmas[i]`` repeated ``x.shape[0]`` times.
            The return value is used as the denoised estimate.
        x: Starting tensor. ``x.shape[0]`` sets the length of the per-sample
            sigma vector passed to ``model``.
        sigmas: Tensor of noise levels indexed by step. ``len(sigmas) - 1``
            steps are taken, so ``x`` is returned unchanged if
            ``len(sigmas) < 2``.
        extra_args: Optional dict of extra keyword arguments forwarded to
            every ``model`` call. ``None`` means no extra arguments.
        callback: Optional callable. At each step it receives a dict with
            keys ``'x'``, ``'i'``, ``'sigma'``, ``'sigma_hat'`` and
            ``'denoised'``, called before ``x`` is updated for that step.
        disable: Forwarded to ``trange``'s ``disable`` argument (presumably
            the tqdm progress-bar toggle).
        order: Maximum number of past derivatives combined per step.
            Must be at least 1.

    Returns:
        The tensor ``x`` after the final step.

    Raises:
        ValueError: If ``order`` is less than 1. Such a value would leave
            ``x`` unchanged while still calling the model at every step.
    """
    from collections import deque

    if order < 1:
        raise ValueError(f"sample_lms: order must be >= 1, got {order!r}")
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    # linear_multistep_coeff is given a host-side NumPy copy of the schedule,
    # not the (possibly on-device) tensor.
    sigmas_cpu = sigmas.detach().cpu().numpy()
    # Sliding window of the most recent derivatives, oldest first. maxlen
    # evicts the oldest entry automatically once `order` are stored.
    ds = deque(maxlen=order)
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        d = to_d(x, sigmas[i], denoised)
        ds.append(d)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i],
                      'sigma_hat': sigmas[i], 'denoised': denoised})
        # After step i only i + 1 derivatives exist, so the first steps run
        # at reduced order.
        cur_order = min(i + 1, order)
        coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
        # coeffs[j] weights the derivative from j steps back (newest first).
        x = x + sum(coeff * d_past for coeff, d_past in zip(coeffs, reversed(ds)))
    return x
| 283 | |
| 284 | |
| 285 | # @torch.no_grad() |
nothing calls this directly
no test coverage detected