(i_, img_, xstartpreds_)
| 230 | xstartpreds_0 = tf.zeros([shape[0], num_recorded_xstartpred, *shape[1:]], dtype=tf.float32) # [B, N, H, W, C] |
| 231 | |
def _loop_body(i_, img_, xstartpreds_):
    """Run one reverse-diffusion step inside the enclosing ``tf.while_loop``.

    Draws a sample of p(x_{t-1} | x_t) at timestep ``i_``. It then writes the
    current x_0 prediction into recording slot ``i_ // include_xstartpred_freq``
    of ``xstartpreds_``, if that slot index lies in
    ``[0, num_recorded_xstartpred)``. Every other slot keeps its previous value.

    Args:
      i_: loop counter used as the timestep (decremented each step).
      img_: the current sample x_t; its static shape is asserted to equal ``shape``.
      xstartpreds_: running buffer of x_0 predictions, laid out [B, N, ...]
        where N == num_recorded_xstartpred.

    Returns:
      ``[i_ - 1, sample, updated_xstartpreds]`` -- the next loop state.
    """
    # Every batch element gets the same timestep i_.
    batch_t = tf.fill([shape[0]], i_)
    x_prev, x0_hat = self.p_sample(
        denoise_fn=denoise_fn,
        x=img_,
        t=batch_t,
        noise_fn=noise_fn,
        return_pred_xstart=True,
    )
    assert x_prev.shape == x0_hat.shape == shape

    # One-hot over the N recording slots. If the slot index falls outside
    # [0, N), tf.one_hot yields all zeros, so nothing gets overwritten.
    slot_index = i_ // include_xstartpred_freq
    slot_mask = tf.one_hot(slot_index, depth=num_recorded_xstartpred, dtype=tf.float32)
    # Broadcastable against [B, N, ...]: shape [1, N, 1, ..., 1].
    broadcast_shape = [1, num_recorded_xstartpred] + [1] * (len(shape) - 1)
    slot_mask = tf.reshape(slot_mask, broadcast_shape)

    # Blend: take the fresh prediction in the selected slot, keep the old buffer elsewhere.
    keep_mask = 1. - slot_mask
    updated = slot_mask * x0_hat[:, None, ...] + keep_mask * xstartpreds_
    return [i_ - 1, x_prev, updated]
| 244 | |
| 245 | _, img_final, xstartpreds_final = tf.while_loop( |
| 246 | cond=lambda i_, img_, xstartpreds_: tf.greater_equal(i_, 0), |
nothing calls this directly
no test coverage detected