`sample_conditional` will return a sample from the conditional distribution. In most cases this means calculating the conditional mean m and variance v and then returning m + sqrt(v) * eps, with eps ~ N(0, 1). However, for some combinations of Mok and Mof more efficient sampling routines exist. The dispatcher will make sure that we use the most efficient one.
(
Xnew: tf.Tensor,
inducing_variable: InducingVariables,
kernel: Kernel,
f: tf.Tensor,
*,
full_cov: bool = False,
full_output_cov: bool = False,
q_sqrt: Optional[tf.Tensor] = None,
white: bool = False,
num_samples: Optional[int] = None,
)
| 38 | "return[2]: [batch..., N, R, R] if (not full_cov) and full_output_cov", |
| 39 | ) |
def _sample_conditional(
    Xnew: tf.Tensor,
    inducing_variable: InducingVariables,
    kernel: Kernel,
    f: tf.Tensor,
    *,
    full_cov: bool = False,
    full_output_cov: bool = False,
    q_sqrt: Optional[tf.Tensor] = None,
    white: bool = False,
    num_samples: Optional[int] = None,
) -> SamplesMeanAndVariance:
    """
    Draw samples from the conditional distribution at ``Xnew``.

    This generic implementation asks ``conditional`` for the conditional mean
    ``m`` and (co)variance ``v`` and then draws ``m + sqrt(v) * eps`` with
    ``eps ~ N(0, 1)`` via ``sample_mvn``. For some combinations of Mok and Mof
    more efficient sampling routines exist; the dispatcher makes sure the most
    efficient one is used.

    :param full_cov: if True, ``conditional`` returns a covariance of shape
        [..., P, N, N] and samples are drawn jointly over the N inputs.
    :param full_output_cov: if True (and ``full_cov`` is False), the covariance
        has shape [..., N, P, P] and samples are drawn jointly over the P outputs.
    :param q_sqrt: forwarded unchanged to ``conditional``.
    :param white: forwarded unchanged to ``conditional``.
    :param num_samples: forwarded to ``sample_mvn``; when given, the samples
        carry an extra S axis ([..., S, N, P]).
    :raises NotImplementedError: if ``full_cov`` and ``full_output_cov`` are both True.
    :return: samples, mean, cov
    """
    if full_cov and full_output_cov:
        raise NotImplementedError(
            "The combination of both `full_cov` and `full_output_cov` is not permitted."
        )

    mean, cov = conditional(
        Xnew,
        inducing_variable,
        kernel,
        f,
        full_cov=full_cov,
        full_output_cov=full_output_cov,
        q_sqrt=q_sqrt,
        white=white,
    )

    if not full_cov:
        # mean: [..., N, P]; cov: [..., N, P] or [..., N, P, P] depending on
        # full_output_cov, which therefore decides whether sampling is joint
        # over the outputs.
        draws = sample_mvn(mean, cov, full_cov=full_output_cov, num_samples=num_samples)
        return draws, mean, cov  # draws: [..., (S), N, P]

    # cov is [..., P, N, N] here, so sample in the output-major layout
    # [..., P, N] that matches it, then swap the last two axes back.
    mean_by_output = tf.linalg.adjoint(mean)  # [..., P, N]
    draws_by_output = sample_mvn(
        mean_by_output, cov, full_cov=True, num_samples=num_samples
    )  # [..., (S), P, N]
    return tf.linalg.adjoint(draws_by_output), mean, cov  # samples: [..., (S), N, P]
nothing calls this directly
no test coverage detected
searching dependent graphs…