MCPcopy Index your code
hub / github.com/Stability-AI/generative-models / DPMPP2MSampler

Class DPMPP2MSampler

sgm/modules/diffusionmodules/sampling.py:290–365  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

288
289
290class DPMPP2MSampler(BaseDiffusionSampler):
def get_variables(self, sigma, next_sigma, previous_sigma=None):
    """Compute the step quantities used by one sampler update.

    ``sigma``, ``next_sigma`` and (optionally) ``previous_sigma`` are each
    mapped through ``to_neg_log_sigma``.

    Returns:
        A 4-tuple ``(h, r, t, t_next)``:
        ``t`` / ``t_next`` are the mapped ``sigma`` / ``next_sigma``,
        ``h`` is ``t_next - t``, and ``r`` is the previous step
        (``t`` minus the mapped ``previous_sigma``) divided by ``h``.
        ``r`` is ``None`` when ``previous_sigma`` is ``None``.
    """
    t = to_neg_log_sigma(sigma)
    t_next = to_neg_log_sigma(next_sigma)
    h = t_next - t

    # Without a previous noise level there is no step ratio to report.
    if previous_sigma is None:
        return h, None, t, t_next

    h_last = t - to_neg_log_sigma(previous_sigma)
    return h, h_last / h, t, t_next
301
def get_mult(self, h, r, t, t_next, previous_sigma):
    """Return the multipliers used to combine ``x`` and the denoiser output.

    The first two multipliers are always returned:
    ``to_sigma(t_next) / to_sigma(t)`` and ``expm1(-h)``.

    When ``previous_sigma`` is not ``None`` two more are appended,
    ``1 + 1 / (2 * r)`` and ``1 / (2 * r)``, so the result is a 4-tuple
    instead of a 2-tuple. ``previous_sigma`` is only tested against
    ``None``; its value is not used.
    """
    sigma_ratio = to_sigma(t_next) / to_sigma(t)
    expm1_neg_h = (-h).expm1()

    if previous_sigma is None:
        return sigma_ratio, expm1_neg_h

    current_weight = 1 + 1 / (2 * r)
    previous_weight = 1 / (2 * r)
    return sigma_ratio, expm1_neg_h, current_weight, previous_weight
312
def sampler_step(
    self,
    old_denoised,
    previous_sigma,
    sigma,
    next_sigma,
    denoiser,
    x,
    cond,
    uc=None,
):
    """Advance ``x`` from ``sigma`` to ``next_sigma`` in a single step.

    The denoiser is evaluated exactly once per call (via ``self.denoise``).
    The plain update is ``mult1 * x - mult2 * denoised`` using the
    multipliers from ``get_mult``. When history is available, a corrected
    update replaces ``denoised`` with
    ``mult3 * denoised - mult4 * old_denoised``.

    Args:
        old_denoised: Denoiser output returned by the previous call, or
            ``None`` on the first step.
        previous_sigma: Noise level of the previous step, or ``None`` on
            the first step.
        sigma: Current noise level.
        next_sigma: Noise level to step to.
        denoiser, cond, uc: Forwarded unchanged to ``self.denoise``.
        x: Current sample.

    Returns:
        ``(x_next, denoised)``; ``denoised`` is meant to be passed back as
        ``old_denoised`` on the following call.
    """
    denoised = self.denoise(x, denoiser, sigma, cond, uc)

    h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma)
    # NOTE(review): append_dims presumably pads each multiplier with
    # trailing dims so it broadcasts against x -- confirm in its definition.
    coeffs = [
        append_dims(coeff, x.ndim)
        for coeff in self.get_mult(h, r, t, t_next, previous_sigma)
    ]

    x_standard = coeffs[0] * x - coeffs[1] * denoised

    # The corrected update needs both the previous denoiser output and the
    # previous noise level: get_mult only returns the two extra
    # coefficients when previous_sigma is given. Checking previous_sigma
    # here as well avoids an IndexError on coeffs[2] if a caller supplies
    # old_denoised without previous_sigma.
    no_history = old_denoised is None or previous_sigma is None
    if no_history or torch.sum(next_sigma) < 1e-14:
        # First step, or every next noise level is (numerically) zero:
        # only the plain update applies, so skip the correction entirely.
        return x_standard, denoised

    denoised_d = coeffs[2] * denoised - coeffs[3] * old_denoised
    x_advanced = coeffs[0] * x - coeffs[1] * denoised_d

    # Element-wise: keep the corrected update only where next_sigma > 0,
    # otherwise fall back to the plain update.
    x = torch.where(
        append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard
    )

    return x, denoised
346
347 def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):

Callers 2

get_samplerFunction · 0.90
get_sampler_configFunction · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected