| 291 | |
@inherit_check_shapes
def __call__(self, X: TensorType) -> tf.Tensor:
    """
    Evaluate the switched function on ``X``, choosing a function per row.

    The last column of ``X`` selects which entry of ``self.functions`` is
    applied to each row; the remaining columns are the inputs handed to that
    function. Rows are grouped by their index, each function is evaluated on
    its group, and the per-function outputs are stitched back together so the
    result follows the original row order of ``X``.

    :param X: Rank-2 tensor whose final column holds the function index. The
        index column is cast to ``tf.int32`` (fractional parts are discarded),
        and every index must lie in ``[0, len(self.functions))`` --
        ``tf.dynamic_partition`` raises otherwise.
    :return: Tensor whose row ``i`` is the output, for row ``i``, of the
        function selected by ``X[i, -1]``. Each function must return exactly
        one output row per input row it receives, with matching trailing
        shapes across functions, as required by ``tf.dynamic_stitch``.
    """
    num_functions = len(self.functions)

    # Column -1 is the switch index; all other columns are the real inputs.
    ind = tf.cast(X[:, -1], tf.int32)
    inputs = X[:, :-1]

    # Split the input rows into one chunk per function, according to ind.
    x_list = tf.dynamic_partition(inputs, ind, num_functions)
    # Evaluate each function on the rows routed to it. Every function is
    # called, even one that no row selects (it then receives an empty chunk).
    results = [f(x) for x, f in zip(x_list, self.functions)]
    # Partition the original row positions the same way, so dynamic_stitch
    # can place each chunk's outputs back at their input rows.
    partitions = tf.dynamic_partition(tf.range(tf.size(ind)), ind, num_functions)
    return tf.dynamic_stitch(partitions, results)
| 307 | |
| 308 | |
| 309 | class SwitchedMeanFunction(SwitchedFunction): |