DPM-Solver. See https://arxiv.org/abs/2206.00927.
| 336 | |
| 337 | |
| 338 | class DPMSolver(nn.Module): |
| 339 | """DPM-Solver. See https://arxiv.org/abs/2206.00927.""" |
| 340 | |
| 341 | def __init__(self, model, extra_args=None, eps_callback=None, info_callback=None): |
| 342 | super().__init__() |
| 343 | self.model = model |
| 344 | self.extra_args = {} if extra_args is None else extra_args |
| 345 | self.eps_callback = eps_callback |
| 346 | self.info_callback = info_callback |
| 347 | |
| 348 | def t(self, sigma): |
| 349 | return -sigma.log() |
| 350 | |
| 351 | def sigma(self, t): |
| 352 | return t.neg().exp() |
| 353 | |
| 354 | def eps(self, eps_cache, key, x, t, *args, **kwargs): |
| 355 | if key in eps_cache: |
| 356 | return eps_cache[key], eps_cache |
| 357 | sigma = self.sigma(t) * x.new_ones([x.shape[0]]) |
| 358 | eps = (x - self.model(x, sigma, *args, **self.extra_args, **kwargs)) / self.sigma(t) |
| 359 | if self.eps_callback is not None: |
| 360 | self.eps_callback() |
| 361 | return eps, {key: eps, **eps_cache} |
| 362 | |
| 363 | def dpm_solver_1_step(self, x, t, t_next, eps_cache=None): |
| 364 | eps_cache = {} if eps_cache is None else eps_cache |
| 365 | h = t_next - t |
| 366 | eps, eps_cache = self.eps(eps_cache, 'eps', x, t) |
| 367 | x_1 = x - self.sigma(t_next) * h.expm1() * eps |
| 368 | return x_1, eps_cache |
| 369 | |
| 370 | def dpm_solver_2_step(self, x, t, t_next, r1=1 / 2, eps_cache=None): |
| 371 | eps_cache = {} if eps_cache is None else eps_cache |
| 372 | h = t_next - t |
| 373 | eps, eps_cache = self.eps(eps_cache, 'eps', x, t) |
| 374 | s1 = t + r1 * h |
| 375 | u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps |
| 376 | eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1) |
| 377 | x_2 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / (2 * r1) * h.expm1() * (eps_r1 - eps) |
| 378 | return x_2, eps_cache |
| 379 | |
| 380 | def dpm_solver_3_step(self, x, t, t_next, r1=1 / 3, r2=2 / 3, eps_cache=None): |
| 381 | eps_cache = {} if eps_cache is None else eps_cache |
| 382 | h = t_next - t |
| 383 | eps, eps_cache = self.eps(eps_cache, 'eps', x, t) |
| 384 | s1 = t + r1 * h |
| 385 | s2 = t + r2 * h |
| 386 | u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps |
| 387 | eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1) |
| 388 | u2 = x - self.sigma(s2) * (r2 * h).expm1() * eps - self.sigma(s2) * (r2 / r1) * ((r2 * h).expm1() / (r2 * h) - 1) * (eps_r1 - eps) |
| 389 | eps_r2, eps_cache = self.eps(eps_cache, 'eps_r2', u2, s2) |
| 390 | x_3 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / r2 * (h.expm1() / h - 1) * (eps_r2 - eps) |
| 391 | return x_3, eps_cache |
| 392 | |
| 393 | def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_sampler=None): |
| 394 | noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler |
| 395 | if not t_end > t_start and eta: |
no outgoing calls
no test coverage detected