MCPcopy
hub / github.com/Vchitect/Latte / ddim_sample

Method ddim_sample

diffusion/gaussian_diffusion.py:517–564  ·  view source on GitHub ↗

Sample x_{t-1} from the model using DDIM. Same usage as p_sample().

(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        eta=0.0,
    )

Source from the content-addressed store, hash-verified

515 img = out["sample"]
516
def ddim_sample(
    self,
    model,
    x,
    t,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    eta=0.0,
):
    """
    Draw x_{t-1} from the model with one DDIM step.

    Usage mirrors p_sample(); the only extra knob is ``eta``.

    :param model: forwarded unchanged to p_mean_variance().
    :param x: the current sample x_t; the noise and the broadcast mask
              are shaped after it.
    :param t: timestep values, one per batch entry (presumably integer
              indices into the alphas_cumprod schedules -- confirm
              against _extract_into_tensor). Entries equal to 0 receive
              no added noise.
    :param clip_denoised: forwarded to p_mean_variance().
    :param denoised_fn: forwarded to p_mean_variance().
    :param cond_fn: if not None, the prediction is passed through
                    condition_score() together with this function.
    :param model_kwargs: forwarded to p_mean_variance() and, when
                         cond_fn is given, to condition_score().
    :param eta: multiplier applied to sigma; with the default 0.0 the
                sigma term is zero, so the noise contributes nothing.
    :return: a dict with keys:
             - 'sample': the computed x_{t-1}.
             - 'pred_xstart': the x_0 prediction used for this step.
    """
    prediction = self.p_mean_variance(
        model,
        x,
        t,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        model_kwargs=model_kwargs,
    )
    if cond_fn is not None:
        prediction = self.condition_score(
            cond_fn, prediction, x, t, model_kwargs=model_kwargs
        )
    x0_hat = prediction["pred_xstart"]

    # The model usually outputs epsilon directly, but it is recomputed
    # from the x_0 estimate so that x_start / x_prev parameterizations
    # go through the same path.
    eps = self._predict_eps_from_xstart(x, t, x0_hat)

    # Cumulative alpha products for the current and previous step,
    # expanded against x.shape by the helper.
    alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
    alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)

    variance_ratio = th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
    step_ratio = th.sqrt(1 - alpha_bar / alpha_bar_prev)
    sigma = eta * variance_ratio * step_ratio

    # Equation 12 (numbering kept from the original comment; presumably
    # refers to the DDIM paper).
    noise = th.randn_like(x)
    eps_scale = th.sqrt(1 - alpha_bar_prev - sigma ** 2)
    mean_pred = x0_hat * th.sqrt(alpha_bar_prev) + eps_scale * eps

    # One mask value per batch entry, reshaped to broadcast over the
    # remaining dimensions of x: no noise when t == 0.
    mask_shape = (-1,) + (1,) * (len(x.shape) - 1)
    nonzero_mask = (t != 0).float().view(mask_shape)

    sample = mean_pred + nonzero_mask * sigma * noise
    return {"sample": sample, "pred_xstart": x0_hat}
565
566 def ddim_reverse_sample(
567 self,

Callers 1

Calls 4

p_mean_varianceMethod · 0.95
condition_scoreMethod · 0.95
_extract_into_tensorFunction · 0.85

Tested by

no test coverage detected