
Function difference_matrix

Defined in gpflow/utilities/ops.py, lines 131–145 (github.com/GPflow/GPflow).

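Returns the pairwise differences between the rows of X and X2, written (X - X2ᵀ) in the docstring. For 2-D inputs of shape [N, D] and [M, D], entry [i, j, :] of the [N, M, D] result is X[i, :] - X2[j, :]. If X2 is None, X is compared with itself.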

Signature: difference_matrix(X: tf.Tensor, X2: Optional[tf.Tensor]) -> tf.Tensor
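The "ᵀ" in (X - X2ᵀ) is shorthand rather than a literal matrix transpose: every row of X is paired with every row of X2 along a new axis. The sketch below illustrates this for 2-D inputs using the same indexing pattern as the function's `X2 is None` branch. It assumes NumPy as a stand-in for TensorFlow (so it runs without TF), and `pairwise_difference` is a hypothetical helper, not a GPflow API.

```python
from typing import Optional

import numpy as np


def pairwise_difference(X: np.ndarray, X2: Optional[np.ndarray] = None) -> np.ndarray:
    """Hypothetical NumPy helper: diff[i, j, :] = X[i, :] - X2[j, :] for 2-D inputs."""
    if X2 is None:
        X2 = X  # compare X with itself, as difference_matrix does when X2 is None
    # [N, 1, D] - [1, M, D] broadcasts to [N, M, D]
    return X[:, np.newaxis, :] - X2[np.newaxis, :, :]


X = np.array([[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]])  # N=3 points, D=2
X2 = np.array([[1.0, 1.0], [0.0, 2.0]])              # M=2 points, D=2

diff = pairwise_difference(X, X2)
print(diff.shape)  # (3, 2, 2)
print(diff[2, 1])  # [4. 3.]  i.e. X[2] - X2[1] = [4, 5] - [0, 2]

self_diff = pairwise_difference(X)
print(self_diff.shape)  # (3, 3, 2)
```

With X2 omitted, the result has a zero diagonal and is antisymmetric in its first two axes: self_diff[i, j] == -self_diff[j, i].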


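```python
from typing import Optional

import tensorflow as tf

# The listing begins inside this function's @check_shapes decorator, which is
# cut off; its last visible shape spec is:
#     "return: [batch..., N, N, D] if X2 is None",
def difference_matrix(X: tf.Tensor, X2: Optional[tf.Tensor]) -> tf.Tensor:
    """
    Returns (X - X2ᵀ)
    """
    if X2 is None:
        # Same-input case: batch dims are shared, result is [batch..., N, N, D].
        X2 = X
        diff = X[..., :, tf.newaxis, :] - X2[..., tf.newaxis, :, :]
        return diff
    # General case, with X: [batch..., N, D] and X2: [batch2..., M, D].
    # Flatten all leading dims so the two batch shapes may differ.
    Xshape = tf.shape(X)
    X2shape = tf.shape(X2)
    X = tf.reshape(X, (-1, Xshape[-1]))
    X2 = tf.reshape(X2, (-1, X2shape[-1]))
    # All row-by-row differences of the flattened inputs.
    diff = X[:, tf.newaxis, :] - X2[tf.newaxis, :, :]
    # Restore the leading dims: [batch..., N, batch2..., M, D].
    diff = tf.reshape(diff, tf.concat((Xshape[:-1], X2shape[:-1], [Xshape[-1]]), 0))
    return diff
```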
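When X2 is given, the function flattens every leading dimension, takes all row-by-row differences, and reshapes the result to [batch..., N, batch2..., M, D], so X and X2 need not share batch shapes. The sketch below retraces those steps in NumPy (used instead of TensorFlow only so it runs standalone). `difference_matrix_np` is a hypothetical name, not part of GPflow.

```python
from typing import Optional

import numpy as np


def difference_matrix_np(X: np.ndarray, X2: Optional[np.ndarray]) -> np.ndarray:
    """Hypothetical NumPy mirror of difference_matrix (not part of GPflow)."""
    if X2 is None:
        # Shared batch dims: [batch..., N, D] -> [batch..., N, N, D]
        return X[..., :, np.newaxis, :] - X[..., np.newaxis, :, :]
    # Flatten all leading dims into a single row axis.
    Xf = X.reshape(-1, X.shape[-1])     # [P, D], P = prod(batch...) * N
    X2f = X2.reshape(-1, X2.shape[-1])  # [Q, D], Q = prod(batch2...) * M
    diff = Xf[:, np.newaxis, :] - X2f[np.newaxis, :, :]  # [P, Q, D]
    # Unflatten rows, then columns: [batch..., N, batch2..., M, D]
    return diff.reshape(X.shape[:-1] + X2.shape[:-1] + (X.shape[-1],))


rng = np.random.default_rng(0)
X = rng.normal(size=(2, 3, 4))   # batch=(2,), N=3, D=4
X2 = rng.normal(size=(5, 6, 4))  # batch2=(5,), M=6, D=4

out = difference_matrix_np(X, X2)
print(out.shape)  # (2, 3, 5, 6, 4)
print(difference_matrix_np(X, None).shape)  # (2, 3, 3, 4)
```

Because the rows are flattened in row-major order, the final reshape splits the first axis back into X's leading shape and the second into X2's. The trade-off is memory: the intermediate holds every cross-batch pair, so its size is prod(batch)·N × prod(batch2)·M × D. The `X2 is None` branch avoids this by keeping batch dims aligned rather than crossed.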

Calls: shape (method)
