MCPcopy
hub / github.com/GPflow/GPflow / _embed

Method _embed

gpflow/kernels/hierarchical.py:388–403  ·  view source on GitHub ↗
(self, X: TensorType)

Source from the content-addressed store, hash-verified

386
@check_shapes("X: [batch..., N, D]", "return: [batch..., N, D_e]")  # D_e = 2*D_cond + D_uncond
def _embed(self, X: TensorType) -> tf.Tensor:
    """
    Build the feature embedding of ``X`` used by this kernel.

    ``X`` is passed through ``self._normalise`` and the last axis is split
    into two groups of columns:

    * unconditional columns (``self._uncond_local_idx``), divided by
      ``self.uncond_lengthscales``;
    * conditional columns (``self._cond_local_idx``), which are handed,
      together with the matching columns of the activity mask from
      ``self._build_activity_mask(X)``, to ``self._embed_conditional``.

    The pieces are concatenated along the last axis, unconditional part
    first. Per the shape annotation, the output width is
    ``D_e = 2*D_cond + D_uncond``.

    :param X: Input locations.
    :return: The concatenated embedding.
    """
    normalised = self._normalise(X)
    # The mask is cast to the dtype of the normalised inputs so it can be
    # combined with them inside _embed_conditional.
    mask = tf.cast(self._build_activity_mask(X), normalised.dtype)

    blocks = []
    if self._n_uncond > 0:
        unc_cols = tf.gather(normalised, self._uncond_local_idx, axis=-1)
        blocks.append(unc_cols / self.uncond_lengthscales)
    if self._n_cond > 0:
        idx = self._cond_local_idx
        # Values first, then their mask columns (same order as the call args).
        blocks.append(
            self._embed_conditional(
                tf.gather(normalised, idx, axis=-1),
                tf.gather(mask, idx, axis=-1),
            )
        )
    if blocks:
        return tf.concat(blocks, axis=-1)

    # Unreachable: HierarchyNode/__init__ guarantee >= 1 feature column.
    empty_shape = tf.concat([tf.shape(X)[:-1], [0]], axis=0)  # pragma: no cover
    return tf.zeros(empty_shape, dtype=normalised.dtype)  # pragma: no cover
404
405 @inherit_check_shapes
406 def K(self, X: TensorType, X2: Optional[TensorType] = None) -> tf.Tensor:

Callers 5

K · Method · 0.95
K_diag · Method · 0.95
test_uncond_only · Method · 0.80

Calls 4

_normalise · Method · 0.95
_build_activity_mask · Method · 0.95
_embed_conditional · Method · 0.95
shape · Method · 0.45