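```python
from typing import Callable

import tensorflow as tf
from check_shapes import check_shapes  # assumed import path (GPflow's shape-checking package)


@check_shapes(
    # ... earlier shape specifications for `a` and `b` are cut off in this excerpt ...
    "return: [a_shape..., b_shape...]",
)
def broadcasting_elementwise(
    op: Callable[[tf.Tensor, tf.Tensor], tf.Tensor], a: tf.Tensor, b: tf.Tensor
) -> tf.Tensor:
    """
    Apply binary operation `op` to every pair of elements drawn from `a` and `b`.

    :param op: binary operator on tensors, e.g. tf.add, tf.subtract
    :return: tensor of shape [a_shape..., b_shape...] whose entry at
        [i..., j...] is op(a[i...], b[j...])
    """
    # Flatten `a` to a column [N, 1] and `b` to a row [1, M]; broadcasting
    # then applies `op` to all N * M pairs, producing an [N, M] tensor.
    flatres = op(tf.reshape(a, [-1, 1]), tf.reshape(b, [1, -1]))
    # Restore the structure: the output shape is shape(a) followed by shape(b).
    return tf.reshape(flatres, tf.concat([tf.shape(a), tf.shape(b)], 0))


# @check_shapes(  <- start of the next definition, truncated in this excerpt
```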
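To make the reshape trick concrete without needing TensorFlow or `check_shapes`, here is a minimal NumPy analogue. The names `broadcasting_elementwise_np` and `pairwise_reference` are hypothetical illustrations, not part of the library. The sketch assumes NumPy broadcasting of an `[N, 1]` column against a `[1, M]` row behaves like TensorFlow's for standard element-wise ops, and it checks the result against an explicit loop over every index pair.

```python
import itertools
from typing import Callable

import numpy as np


def broadcasting_elementwise_np(
    op: Callable[[np.ndarray, np.ndarray], np.ndarray], a: np.ndarray, b: np.ndarray
) -> np.ndarray:
    """Hypothetical NumPy analogue: op(a[i...], b[j...]) for every index pair."""
    a = np.asarray(a)
    b = np.asarray(b)
    # Flatten a into a column [N, 1] and b into a row [1, M]; broadcasting
    # evaluates op on all N * M pairs, giving an [N, M] array.
    flat = op(a.reshape(-1, 1), b.reshape(1, -1))
    # Restore the original structure: output shape is a.shape + b.shape.
    return flat.reshape(a.shape + b.shape)


def pairwise_reference(op, a, b):
    """Hypothetical reference: explicit loop over every (index_a, index_b) pair."""
    a = np.asarray(a)
    b = np.asarray(b)
    out = np.empty(a.shape + b.shape, dtype=np.result_type(a, b))
    for ia, ib in itertools.product(np.ndindex(a.shape), np.ndindex(b.shape)):
        # Concatenating the two index tuples addresses the [a_shape..., b_shape...] slot.
        out[ia + ib] = op(a[ia], b[ib])
    return out


a = np.array([[1.0, 2.0], [3.0, 4.0]])  # shape [2, 2]
b = np.array([10.0, 20.0, 30.0])  # shape [3]
diff = broadcasting_elementwise_np(np.subtract, a, b)
print(diff.shape)  # (2, 2, 3)
print(np.allclose(diff, pairwise_reference(np.subtract, a, b)))  # True
```

The flatten-then-reshape design replaces the nested loop with a single vectorized call to `op`, at the cost of materializing all `size(a) * size(b)` results at once. The TensorFlow version uses `tf.shape` (the dynamic shape) when reshaping, so it also works when dimensions are unknown until run time.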