Returns (X - X2ᵀ): the pairwise differences X[..., i, :] - X2[..., j, :] for every pair of rows.
(X: tf.Tensor, X2: Optional[tf.Tensor])
| 129 | "return: [batch..., N, N, D] if X2 is None", |
| 130 | ) |
def difference_matrix(X: tf.Tensor, X2: Optional[tf.Tensor]) -> tf.Tensor:
    """
    Returns (X - X2ᵀ): every pairwise difference between rows of X and X2.

    The last axis of each input is treated as the feature axis and is kept as
    the last axis of the result.

    :param X: Tensor whose last axis holds features, shape [..., N, D].
    :param X2: Optional second tensor whose last axis holds features. If None,
        X is differenced against itself.
    :return: If X2 is None, a tensor of shape [..., N, N, D] whose
        [..., i, j, :] entry is X[..., i, :] - X[..., j, :]. Here the leading
        axes of X are shared (broadcast), not crossed.
        Otherwise, all leading axes of X and X2 are flattened and crossed,
        giving shape [*X.shape[:-1], *X2.shape[:-1], D], where each entry is
        one row of X minus one row of X2.
    """
    if X2 is None:
        # Insert singleton axes so the N x N grid is produced by broadcasting.
        # Any leading (batch) axes of X line up with themselves.
        return X[..., :, tf.newaxis, :] - X[..., tf.newaxis, :, :]

    # With two inputs, leading axes are crossed rather than broadcast.
    # Collapse each input to a 2-D [rows, D] matrix and difference every row
    # pair. Then restore X's leading axes followed by X2's leading axes.
    x_shape = tf.shape(X)
    x2_shape = tf.shape(X2)
    x_rows = tf.reshape(X, (-1, x_shape[-1]))
    x2_rows = tf.reshape(X2, (-1, x2_shape[-1]))
    row_diffs = x_rows[:, tf.newaxis, :] - x2_rows[tf.newaxis, :, :]
    result_shape = tf.concat((x_shape[:-1], x2_shape[:-1], [x_shape[-1]]), 0)
    return tf.reshape(row_diffs, result_shape)
| 146 | |
| 147 | |
| 148 | @check_shapes( |
searching dependent graphs…