MCP · Copy · Index your code
hub / github.com/Stability-AI/generative-models / HeunEDMSampler

Class HeunEDMSampler

sgm/modules/diffusionmodules/sampling.py:221–237  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

219
220
class HeunEDMSampler(EDMSampler):
    """EDM sampler variant that adds a Heun (second-order) correction.

    Only ``possible_correction_step`` is overridden here; everything else,
    including how the first-order proposal is produced, is inherited from
    ``EDMSampler`` (not visible in this snippet).
    """

    def possible_correction_step(
        self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
    ):
        """Refine a first-order proposal with Heun's trapezoidal rule.

        Args:
            euler_step: Proposed next state. It is returned unchanged when
                no correction applies. Presumably this is the Euler
                prediction built by the caller; confirm in EDMSampler.
            x: State the corrected step starts from.
            d: Derivative at ``x``. Presumably computed with ``to_d`` at
                the current sigma; confirm at the call site.
            dt: Step size applied to the averaged derivative.
            next_sigma: Noise level(s) of the target step. Elements that
                are not > 0 keep ``euler_step``.
            denoiser, cond, uc: Forwarded unchanged to ``self.denoise``.

        Returns:
            ``euler_step`` as-is when ``sum(next_sigma) < 1e-14``. In that
            case no network evaluation happens.

            Otherwise, returns an element-wise selection between two values:
            ``x + (d + d_new) / 2.0 * dt`` where ``next_sigma > 0``, and
            ``euler_step`` elsewhere. Here ``d_new`` is the derivative
            obtained from a denoiser evaluation at ``euler_step``.
        """
        # Skip the extra denoiser call when every target noise level is
        # (numerically) zero.
        # NOTE(review): this test uses the *summed* sigma against 1e-14,
        # while the mask below uses a per-element ``> 0.0``. A batch of tiny
        # but positive sigmas therefore skips the correction entirely.
        if torch.sum(next_sigma) < 1e-14:
            return euler_step

        # Evaluate the slope at the proposed point ...
        denoised_at_proposal = self.denoise(
            euler_step, denoiser, next_sigma, cond, uc
        )
        slope_at_proposal = to_d(euler_step, next_sigma, denoised_at_proposal)
        # ... and average it with the slope at the start (trapezoidal rule).
        mean_slope = (d + slope_at_proposal) / 2.0

        # Only elements with a positive target noise level take the
        # corrected step; the rest fall back to the first-order proposal.
        wants_correction = append_dims(next_sigma, x.ndim) > 0.0
        corrected = x + mean_slope * dt
        return torch.where(wants_correction, corrected, euler_step)
238
239
240class EulerAncestralSampler(AncestralSampler):

Callers 2

get_sampler (Function) · 0.90
get_sampler_config (Function) · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected