Generate samples
(self, denoise_fn, *, shape, noise_fn=tf.random_normal)
| 198 | return (sample, pred_xstart) if return_pred_xstart else sample |
| 199 | |
def p_sample_loop(self, denoise_fn, *, shape, noise_fn=tf.random_normal):
    """
    Generate samples

    Draws an initial float32 tensor with ``noise_fn`` and then calls
    ``self.p_sample`` once per timestep, counting down from
    ``self.num_timesteps - 1`` to ``0`` inclusive, inside a ``tf.while_loop``
    built with ``back_prop=False`` (no gradients through the loop).

    :param denoise_fn: model callable, forwarded unchanged to ``self.p_sample``.
    :param shape: tuple or list giving the output shape. ``shape[0]`` sets how
        many timestep entries are passed per step (presumably the batch size).
    :param noise_fn: callable invoked as ``noise_fn(shape=..., dtype=...)`` for
        the starting tensor; also forwarded to ``self.p_sample``.
    :return: the tensor left after the final step; its static shape is
        asserted to equal ``shape``.
    """
    assert isinstance(shape, (tuple, list))

    def _keep_going(step, _):
        # Continue while the timestep counter has not dropped below zero.
        return tf.greater_equal(step, 0)

    def _denoise_once(step, x_t):
        # Ops are created in the same order as a [step - 1, p_sample(...)]
        # list literal would create them (sub, fill, then p_sample), so any
        # graph-seed-derived op seeds stay identical.
        prev_step = step - 1
        # Broadcast the scalar timestep to one entry per row of the batch.
        t_batch = tf.fill([shape[0]], step)
        x_prev = self.p_sample(
            denoise_fn=denoise_fn, x=x_t, t=t_batch, noise_fn=noise_fn, return_pred_xstart=False)
        return [prev_step, x_prev]

    start_step = tf.constant(self.num_timesteps - 1, dtype=tf.int32)
    start_img = noise_fn(shape=shape, dtype=tf.float32)
    _, samples = tf.while_loop(
        cond=_keep_going,
        body=_denoise_once,
        loop_vars=[start_step, start_img],
        shape_invariants=[start_step.shape, start_img.shape],
        back_prop=False,
    )
    assert samples.shape == shape
    return samples
| 220 | |
| 221 | def p_sample_loop_progressive(self, denoise_fn, *, shape, noise_fn=tf.random_normal, include_xstartpred_freq=50): |
| 222 | """ |