Sample x_{t-1} from the model at the given timestep. :param model: the model to sample from. :param x: the current tensor at x_t. :param t: the value of t, starting at 0 for the first diffusion step. :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
(
self,
model,
x,
t,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
)
| 378 | return out |
| 379 | |
def p_sample(
    self,
    model,
    x,
    t,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
):
    """
    Sample x_{t-1} from the model at the given timestep.

    :param model: the model to sample from.
    :param x: the current (noisy) tensor at x_t; x_{t-1} is sampled from it.
    :param t: the value of t, starting at 0 for the first diffusion step.
        Presumably one entry per batch element: it is reshaped to
        (-1, 1, ..., 1) and broadcast against x, so its length must match
        x's first dimension (or be 1).
    :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
    :param denoised_fn: if not None, a function which applies to the
        x_start prediction before it is used to sample.
    :param cond_fn: if not None, this is a gradient function that acts
        similarly to the model. It is applied through
        ``self.condition_mean``, which replaces the predicted mean; the
        predicted log-variance is left unchanged.
    :param model_kwargs: if not None, a dict of extra keyword arguments to
        pass to the model. This can be used for conditioning.
    :return: a dict containing the following keys:
             - 'sample': a random sample from the model.
             - 'pred_xstart': a prediction of x_0.
    """
    out = self.p_mean_variance(
        model,
        x,
        t,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        model_kwargs=model_kwargs,
    )
    noise = th.randn_like(x)
    # Per-sample mask shaped (B, 1, ..., 1) so it broadcasts over x:
    # 0 where t == 0 (the final step returns the mean with no noise), else 1.
    nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
    if cond_fn is not None:
        # Guidance only shifts the mean; the variance used below is unchanged.
        out["mean"] = self.condition_mean(
            cond_fn, out, x, t, model_kwargs=model_kwargs
        )
    # exp(0.5 * log_variance) is the standard deviation of the posterior step.
    sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
    return {"sample": sample, "pred_xstart": out["pred_xstart"]}
| 422 | |
| 423 | def p_sample_loop( |
| 424 | self, |
no test coverage detected