```python
def expectation(
    p: ProbabilityDistributionLike,
    obj1: PackedExpectationObject,
    obj2: Optional[PackedExpectationObject] = None,
    nghp: Optional[int] = None,
) -> tf.Tensor:
    """
    Compute the expectation <obj1(x) obj2(x)>_p(x).
    Uses multiple dispatch to select an analytical implementation,
    if one is available. If not, it falls back to quadrature.

    :type p: (mu, cov) tuple or a `ProbabilityDistribution` object
    :type obj1: kernel, mean function, (kernel, inducing_variable), or None
    :type obj2: kernel, mean function, (kernel, inducing_variable), or None
    :param int nghp: passed to `quadrature_expectation` to set the number
        of Gauss-Hermite points used: `num_gauss_hermite_points`
    :return: a 1-D, 2-D, or 3-D tensor containing the expectation

    Allowed combinations:

    - Psi statistics:
        >>> eKdiag = expectation(p, kernel)  # (N), Psi0
        >>> eKxz = expectation(p, (kernel, inducing_variable))  # (NxM), Psi1
        >>> exKxz = expectation(p, identity_mean, (kernel, inducing_variable))  # (NxDxM)
        >>> eKzxKxz = expectation(p, (kernel, inducing_variable), (kernel, inducing_variable))  # (NxMxM), Psi2

    - kernels and mean functions:
        >>> eKzxMx = expectation(p, (kernel, inducing_variable), mean)  # (NxMxQ)
        >>> eMxKxz = expectation(p, mean, (kernel, inducing_variable))  # (NxQxM)

    - only mean functions:
        >>> eMx = expectation(p, mean)  # (NxQ)
        >>> eM1x_M2x = expectation(p, mean1, mean2)  # (NxQ1xQ2)
        .. note:: mean(x) is 1xQ (row vector)

    - different kernels. This occurs, for instance, when calculating Psi2 for Sum kernels:
        >>> eK1zxK2xz = expectation(p, (kern1, inducing_variable), (kern2, inducing_variable))  # (NxMxM)
    """
    p, obj1, feat1, obj2, feat2 = _init_expectation(p, obj1, obj2)
    try:
        # Analytic implementation selected by multiple dispatch on the argument types.
        return dispatch.expectation(p, obj1, feat1, obj2, feat2, nghp=nghp)
    except NotImplementedError:
        # No analytic rule registered for this combination: use quadrature instead.
        return dispatch.quadrature_expectation(p, obj1, feat1, obj2, feat2, nghp=nghp)
```
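The `try`/`except NotImplementedError` in `expectation` is what turns a dispatch table into "analytic if registered, quadrature otherwise". The sketch below reproduces only that control flow in plain Python/NumPy for a scalar Gaussian, as an illustration under stated assumptions: the registry, `register`, `Gaussian1D`, `Linear`, `Cosine`, `toy_analytic_expectation`, `toy_quadrature_expectation`, and `toy_expectation` are all hypothetical stand-ins, not GPflow's `dispatch` module.

```python
import math
from dataclasses import dataclass
from typing import Callable, Dict, Tuple

import numpy as np


@dataclass
class Gaussian1D:
    """Hypothetical scalar Gaussian p(x) = N(mu, var)."""
    mu: float
    var: float


@dataclass
class Linear:
    """Hypothetical object f(x) = a * x + b, which has an analytic rule."""
    a: float
    b: float

    def __call__(self, x):
        return self.a * x + self.b


class Cosine:
    """Hypothetical object f(x) = cos(x), with no analytic rule registered."""

    def __call__(self, x):
        return np.cos(x)


# Hypothetical registry keyed on the runtime types of (p, f).
_ANALYTIC: Dict[Tuple[type, type], Callable] = {}


def register(dist_type: type, obj_type: type):
    """Decorator that records an analytic expectation for one type pair."""
    def wrap(fn: Callable) -> Callable:
        _ANALYTIC[(dist_type, obj_type)] = fn
        return fn
    return wrap


@register(Gaussian1D, Linear)
def _linear_expectation(p: Gaussian1D, f: Linear) -> float:
    # E[a x + b] = a * mu + b, exactly.
    return f.a * p.mu + f.b


def toy_analytic_expectation(p, f) -> float:
    """Look up an analytic rule; raise NotImplementedError if none exists."""
    fn = _ANALYTIC.get((type(p), type(f)))
    if fn is None:
        raise NotImplementedError(
            f"no analytic rule for ({type(p).__name__}, {type(f).__name__})"
        )
    return fn(p, f)


def toy_quadrature_expectation(p: Gaussian1D, f, nghp: int = 20) -> float:
    # Gauss-Hermite: E[f(x)] ~= sum_i w_i * f(mu + sqrt(2 * var) * t_i) / sqrt(pi).
    t, w = np.polynomial.hermite.hermgauss(nghp)
    x = p.mu + math.sqrt(2.0 * p.var) * t
    return float(np.sum(w * f(x)) / math.sqrt(math.pi))


def toy_expectation(p, f, nghp=None) -> float:
    """Analytic rule if one is registered, otherwise Gauss-Hermite quadrature."""
    try:
        return toy_analytic_expectation(p, f)
    except NotImplementedError:
        return toy_quadrature_expectation(p, f, nghp=20 if nghp is None else nghp)


p = Gaussian1D(mu=0.5, var=0.3)
print(toy_expectation(p, Linear(a=2.0, b=1.0)))  # analytic branch: 2.0
print(toy_expectation(p, Cosine()))  # quadrature branch, close to cos(0.5) * exp(-0.15)
```

Signalling "no rule" with an exception keeps the caller unaware of which combinations are implemented, so adding a new analytic rule needs no change to the entry point.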
```python
def quadrature_expectation(
```
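`nghp` sets the number of Gauss-Hermite points used when `expectation` falls back to quadrature. As a sketch of how such a rule can extend to a D-dimensional Gaussian with full covariance, the NumPy function below builds a tensor-product grid and applies the change of variables x = mu + sqrt(2) L t, with L the Cholesky factor of the covariance. `gauss_hermite_expectation` is a hypothetical helper; GPflow's own quadrature code may organise the grid and batching differently.

```python
import itertools
import math

import numpy as np


def gauss_hermite_expectation(f, mu, cov, nghp=10):
    """Approximate E[f(x)] for x ~ N(mu, cov) with a tensor-product Gauss-Hermite rule.

    Hypothetical helper. f maps an (S, D) array of points to an (S, ...) array
    of values; nghp is the number of points per dimension, so S = nghp ** D.
    """
    mu = np.asarray(mu, dtype=float)
    cov = np.asarray(cov, dtype=float)
    d = mu.shape[0]
    t, w = np.polynomial.hermite.hermgauss(nghp)
    # Every combination of 1-D nodes, and the product of the matching 1-D weights.
    grid = np.array(list(itertools.product(t, repeat=d)))  # (S, D)
    weights = np.prod(np.array(list(itertools.product(w, repeat=d))), axis=1)  # (S,)
    # x = mu + sqrt(2) * L t maps the Hermite weight exp(-t^T t) onto N(mu, cov).
    chol = np.linalg.cholesky(cov)
    x = mu + math.sqrt(2.0) * grid @ chol.T  # (S, D)
    # Weighted sum over grid points, normalised by pi^(D/2).
    return np.tensordot(weights, f(x), axes=1) / math.pi ** (d / 2)


mu = np.array([0.3, -0.2])
cov = np.array([[0.5, 0.1], [0.1, 0.4]])
# Second moment E[x x^T]; a degree-2 polynomial, so a few points already make it exact.
print(gauss_hermite_expectation(lambda x: x[:, :, None] * x[:, None, :], mu, cov, nghp=5))
```

The grid has nghp ** D points, so this product rule is only practical for small D; that cost is one reason an analytic implementation is preferred whenever one is available.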
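The Psi1 statistic listed in the docstring, `expectation(p, (kernel, inducing_variable))` with shape NxM, is a case where an analytic rule can exist. For a squared-exponential (RBF) kernel k(x, z) = variance * exp(-0.5 * sum_d (x_d - z_d)^2 / l_d^2) and a diagonal Gaussian x ~ N(mu_n, diag(var_n)), Gaussian convolution gives E[k(x, z_m)] = variance * prod_d sqrt(l_d^2 / (l_d^2 + var_nd)) * exp(-0.5 * sum_d (mu_nd - z_md)^2 / (l_d^2 + var_nd)). The NumPy sketch below evaluates that formula; `rbf_psi1` is a hypothetical, GPflow-independent helper, not the library's implementation.

```python
import numpy as np


def rbf_psi1(mu, var, z, variance=1.0, lengthscales=1.0):
    """Hypothetical closed-form Psi1[n, m] = E_{x ~ N(mu[n], diag(var[n]))}[k(x, z[m])].

    mu: (N, D) means, var: (N, D) diagonal variances, z: (M, D) inducing inputs.
    Assumes an RBF kernel with the given variance and per-dimension lengthscales.
    """
    mu = np.asarray(mu, dtype=float)
    var = np.asarray(var, dtype=float)
    z = np.asarray(z, dtype=float)
    l2 = np.broadcast_to(np.asarray(lengthscales, dtype=float) ** 2, (mu.shape[1],))
    denom = l2 + var  # (N, D): lengthscale^2 widened by the input variance
    scale = np.prod(np.sqrt(l2 / denom), axis=1)  # (N,)
    diff = mu[:, None, :] - z[None, :, :]  # (N, M, D)
    quad = np.sum(diff ** 2 / denom[:, None, :], axis=-1)  # (N, M)
    return variance * scale[:, None] * np.exp(-0.5 * quad)  # (N, M)


mu = np.array([[0.1, -0.3], [0.8, 0.4], [-0.5, 0.2]])
var = np.array([[0.2, 0.05], [0.1, 0.3], [0.4, 0.1]])
z = np.array([[0.0, 0.0], [0.5, -0.5]])
print(rbf_psi1(mu, var, z, variance=1.3, lengthscales=[0.7, 1.2]).shape)  # (3, 2)
```

With var set to zero the expression reduces to the kernel matrix k(mu, z), and a Monte Carlo average of k(x, z) over samples of x converges to it, which makes both convenient sanity checks.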