Repository: github.com/Vchitect/Latte

Method p_sample

diffusion/gaussian_diffusion.py:380–421

Sample x_{t-1} from the model at the given timestep, given the current tensor x_t.
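Written as an equation, the step computed on the last lines of the method is shown below. Here μ_θ and log σ_t² stand for out["mean"] and out["log_variance"] as returned by p_mean_variance; this notation is ours, not the source's.

```latex
x_{t-1} = \mu_\theta(x_t, t) + \mathbb{1}[t \neq 0]\,\exp\!\left(\tfrac{1}{2}\log\sigma_t^2\right) z,
\qquad z \sim \mathcal{N}(0, I)
```

Since exp(½ log σ_t²) = σ_t, for t > 0 this is a draw from N(μ_θ, σ_t² I). At t = 0 the indicator is zero and the method returns the mean itself. If cond_fn is passed, μ_θ is first replaced by the output of condition_mean.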

(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
    )


def p_sample(
    self,
    model,
    x,
    t,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
):
    """
    Sample x_{t-1} from the model at the given timestep.
    :param model: the model to sample from.
    :param x: the current tensor at x_t.
    :param t: the value of t, starting at 0 for the first diffusion step.
    :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
    :param denoised_fn: if not None, a function which applies to the
        x_start prediction before it is used to sample.
    :param cond_fn: if not None, this is a gradient function that acts
        similarly to the model.
    :param model_kwargs: if not None, a dict of extra keyword arguments to
        pass to the model. This can be used for conditioning.
    :return: a dict containing the following keys:
             - 'sample': a random sample from the model.
             - 'pred_xstart': a prediction of x_0.
    """
    # Mean, log-variance and x_0 prediction of p(x_{t-1} | x_t).
    out = self.p_mean_variance(
        model,
        x,
        t,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        model_kwargs=model_kwargs,
    )
    noise = th.randn_like(x)  # th is torch, imported at module level
    # Shape (N, 1, ..., 1): one mask value per batch element, broadcast over x.
    nonzero_mask = (
        (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
    )  # no noise when t == 0
    if cond_fn is not None:
        # Replace the mean with the guided mean computed from cond_fn.
        out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
    # mean + sigma * noise, with sigma = exp(0.5 * log_variance).
    sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
    return {"sample": sample, "pred_xstart": out["pred_xstart"]}
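The tensor arithmetic at the end of p_sample can be mirrored without PyTorch. The sketch below is a hypothetical, list-based stand-in (the name sample_step is ours, not the repository's): nested lists of shape batch x features replace tensors, and random.Random.gauss replaces th.randn_like. It illustrates two points: the mask is computed once per batch element and shared by all of that element's features, and the standard deviation is recovered from the log-variance as exp(0.5 * log_variance).

```python
import math
import random


def sample_step(mean, log_variance, t, rng=None):
    """Hypothetical list-based sketch of the final lines of p_sample.

    mean, log_variance: batch x features nested lists standing in for
        out["mean"] and out["log_variance"].
    t: one integer timestep per batch element.
    """
    if rng is None:
        rng = random.Random()
    samples = []
    for mu_row, logvar_row, t_i in zip(mean, log_variance, t):
        # One mask value per batch element, shared by all of its features;
        # in the original, view(-1, 1, ..., 1) achieves this by broadcasting.
        mask = 0.0 if t_i == 0 else 1.0
        row = []
        for mu, logvar in zip(mu_row, logvar_row):
            noise = rng.gauss(0.0, 1.0)  # analogue of th.randn_like(x)
            std = math.exp(0.5 * logvar)  # sigma = exp(log(sigma^2) / 2)
            row.append(mu + mask * std * noise)
        samples.append(row)
    return samples


# Usage: two batch elements with identical statistics, one at t == 0.
rng = random.Random(0)
out = sample_step([[0.5, -0.5], [0.5, -0.5]],
                  [[-2.0, -2.0], [-2.0, -2.0]],
                  t=[0, 10], rng=rng)
print(out[0])  # [0.5, -0.5]: the t == 0 element is returned as its mean
```

Using a multiplicative mask instead of a branch lets the real method handle a batch whose elements sit at different timesteps in one vectorized expression.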

Callers: 1
Calls: p_mean_variance (method), condition_mean (method)
Tested by: no test coverage detected
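When cond_fn is given, p_sample replaces out["mean"] with the result of condition_mean before drawing the sample. The body of condition_mean is not shown on this page, so the sketch below assumes the guided-diffusion convention: the mean is shifted by the model variance times the gradient of log p(y | x) that cond_fn returns. The names condition_mean_sketch and toward_target are hypothetical, and plain Python lists stand in for tensors.

```python
def condition_mean_sketch(cond_fn, mean, variance, x, t):
    """Assumed convention: new_mean = mean + variance * cond_fn(x, t).

    cond_fn returns the gradient of log p(y | x) with respect to x,
    one value per element of x (hypothetical list-based stand-in).
    """
    gradient = cond_fn(x, t)
    return [m + v * g for m, v, g in zip(mean, variance, gradient)]


def toward_target(x, t, target=1.0, scale=1.0):
    """Hypothetical cond_fn: gradient of -scale/2 * (x - target)**2,
    which nudges every element toward `target`."""
    return [scale * (target - xi) for xi in x]


# Usage: elements below the target move up, elements above it move down.
print(condition_mean_sketch(toward_target, [0.0, 2.0], [0.5, 0.5], [0.0, 2.0], 10))
# [0.5, 1.5]
```

Under this convention, a larger variance produces a larger shift for the same gradient.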