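Generate samples from the model. :param model: the model module. :param shape: the shape of the samples, (N, C, H, W). :param noise: if specified, the noise from the encoder to sample. Should be of the same shape as `shape`. :param clip_denoised: if True, clip x_start predictions to [-1, 1]. :param denoised_fn: if not None, a function which applies to the x_start prediction before it is used to sample. :param cond_fn: if not None, this is a gradient function that acts similarly to the model. :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :param device: if specified, the device to create the samples on. If not specified, use a model parameter's device. :param progress: if True, show a tqdm progress bar. :return: a non-differentiable batch of samples.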
(
self,
model,
shape,
noise=None,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
device=None,
progress=False,
)
| 421 | return {"sample": sample, "pred_xstart": out["pred_xstart"]} |
| 422 | |
def p_sample_loop(
    self,
    model,
    shape,
    noise=None,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    device=None,
    progress=False,
):
    """
    Generate samples from the model.

    Drives ``self.p_sample_loop_progressive`` to completion with the same
    arguments and returns the ``"sample"`` entry of the LAST dict it yields;
    every intermediate step is discarded.

    :param model: the model module.
    :param shape: the shape of the samples, (N, C, H, W).
    :param noise: if specified, the noise from the encoder to sample.
                  Should be of the same shape as `shape`.
    :param clip_denoised: if True, clip x_start predictions to [-1, 1].
    :param denoised_fn: if not None, a function which applies to the
        x_start prediction before it is used to sample.
    :param cond_fn: if not None, this is a gradient function that acts
                    similarly to the model.
    :param model_kwargs: if not None, a dict of extra keyword arguments to
        pass to the model. This can be used for conditioning.
    :param device: if specified, the device to create the samples on.
                   If not specified, use a model parameter's device.
    :param progress: if True, show a tqdm progress bar.
    :return: a non-differentiable batch of samples.
    :raises RuntimeError: if ``p_sample_loop_progressive`` yields nothing,
        rather than failing with an opaque ``TypeError`` from indexing None.
    """
    final = None
    # Exhaust the progressive generator, remembering only the most recent
    # step; only the final step's output is returned to the caller.
    for sample in self.p_sample_loop_progressive(
        model,
        shape,
        noise=noise,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        cond_fn=cond_fn,
        model_kwargs=model_kwargs,
        device=device,
        progress=progress,
    ):
        final = sample
    if final is None:
        # Without this guard, `final["sample"]` would raise
        # "TypeError: 'NoneType' object is not subscriptable".
        raise RuntimeError(
            "p_sample_loop_progressive yielded no samples; "
            "cannot return a final sample."
        )
    return final["sample"]
| 467 | |
| 468 | def p_sample_loop_progressive( |
| 469 | self, |
no test coverage detected