Sample from the model
(self, denoise_fn, *, x, t, noise_fn, clip_denoised=True, return_pred_xstart: bool)
| 184 | # === Sampling === |
| 185 | |
| 186 | def p_sample(self, denoise_fn, *, x, t, noise_fn, clip_denoised=True, return_pred_xstart: bool): |
| 187 | """ |
| 188 | Sample from the model |
| 189 | """ |
| 190 | model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance( |
| 191 | denoise_fn, x=x, t=t, clip_denoised=clip_denoised, return_pred_xstart=True) |
| 192 | noise = noise_fn(shape=x.shape, dtype=x.dtype) |
| 193 | assert noise.shape == x.shape |
| 194 | # no noise when t == 0 |
| 195 | nonzero_mask = tf.reshape(1 - tf.cast(tf.equal(t, 0), tf.float32), [x.shape[0]] + [1] * (len(x.shape) - 1)) |
| 196 | sample = model_mean + nonzero_mask * tf.exp(0.5 * model_log_variance) * noise |
| 197 | assert sample.shape == pred_xstart.shape |
| 198 | return (sample, pred_xstart) if return_pred_xstart else sample |
| 199 | |
| 200 | def p_sample_loop(self, denoise_fn, *, shape, noise_fn=tf.random_normal): |
| 201 | """ |
No test coverage detected.