| 386 | |
@check_shapes("X: [batch..., N, D]", "return: [batch..., N, D_e]")  # D_e = 2*D_cond + D_uncond
def _embed(self, X: TensorType) -> tf.Tensor:
    """Map raw inputs ``X`` to the embedded feature tensor.

    The input is normalised, then split by column group:

    * unconditional columns (``self._uncond_local_idx``) are divided by
      ``self.uncond_lengthscales``;
    * conditional columns (``self._cond_local_idx``) are handed, together
      with the matching slice of the activity mask, to
      ``self._embed_conditional``.

    The resulting pieces are concatenated along the last axis,
    unconditional first, conditional second.

    :param X: input tensor of shape ``[batch..., N, D]``.
    :return: embedded tensor of shape ``[batch..., N, D_e]``; per the
        shape annotation above, ``D_e = 2*D_cond + D_uncond``.
    """
    normalised = self._normalise(X)
    # The mask is cast to the normalised dtype so it can be combined
    # arithmetically with the feature values downstream.
    activity = tf.cast(self._build_activity_mask(X), normalised.dtype)

    pieces = []
    if self._n_uncond > 0:
        pieces.append(
            tf.gather(normalised, self._uncond_local_idx, axis=-1)
            / self.uncond_lengthscales
        )
    if self._n_cond > 0:
        cond_values = tf.gather(normalised, self._cond_local_idx, axis=-1)
        cond_mask = tf.gather(activity, self._cond_local_idx, axis=-1)
        pieces.append(self._embed_conditional(cond_values, cond_mask))

    if pieces:
        return tf.concat(pieces, axis=-1)

    # Unreachable: HierarchyNode/__init__ guarantee >= 1 feature column.
    # Kept as a defensive fallback returning a zero-width embedding.
    empty_shape = tf.concat([tf.shape(X)[:-1], [0]], axis=0)  # pragma: no cover
    return tf.zeros(empty_shape, dtype=normalised.dtype)  # pragma: no cover
| 404 | |
| 405 | @inherit_check_shapes |
| 406 | def K(self, X: TensorType, X2: Optional[TensorType] = None) -> tf.Tensor: |