(y, x)
| 97 | batch_ids = tf.tile(tf.expand_dims(tf.expand_dims(tf.range(B), axis=-1), axis=-1), [1, h, w]) |
| 98 | |
| 99 | def get(y, x):  # Look up img_flat at position (y, x) for each batch element; closes over batch_ids, h, w and img_flat from the enclosing scope. |
| 100 | idx = tf.reshape(batch_ids * h * w + y * w + x, [-1])  # Flat offset batch*h*w + y*w + x (y strides by w, x by 1), flattened to 1-D. |
| 101 | idx = tf.cast(idx, tf.int32)  # tf.gather takes integer indices. NOTE(review): assumes y and x already hold whole-number values -- confirm at the caller. |
| 102 | return tf.gather(img_flat, idx)  # One gathered entry per index; presumably img_flat is laid out as (B*h*w, c) -- TODO confirm. |
| 103 | |
| 104 | val = tf.zeros_like(alpha) |
| 105 | val += (1 - alpha) * (1 - beta) * tf.reshape(get(yT, xL), [-1, h, w, c]) |