github.com/GPflow/GPflow

Function _sample_conditional

gpflow/conditionals/sample_conditionals.py:40–91

`sample_conditional` returns a sample from the conditional distribution. In most cases this means computing the conditional mean m and variance v and then returning m + sqrt(v) * eps, with eps ~ N(0, 1). However, for some combinations of Mok and Mof (multi-output kernel and multi-output inducing-variable types), more efficient sampling routines exist, and the dispatcher selects the most efficient one.
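The recipe in the summary can be made concrete with a small sketch. It uses NumPy as a stand-in for the TensorFlow ops GPflow uses, and `sample_diag` and `sample_full` are hypothetical helper names, not part of GPflow. When the covariance is a full matrix rather than a vector of variances, the usual generalisation of sqrt(v) is a Cholesky factor L with L @ L.T = cov, so the sample becomes m + L @ eps.

```python
import numpy as np


def sample_diag(mean, var, rng):
    """One draw of m + sqrt(v) * eps with eps ~ N(0, 1), element by element."""
    eps = rng.standard_normal(mean.shape)
    return mean + np.sqrt(var) * eps


def sample_full(mean, cov, rng, jitter=1e-9):
    """One draw of m + L @ eps, where L is the Cholesky factor of cov."""
    n = mean.shape[-1]
    # Assumption: a tiny diagonal jitter keeps the factorisation numerically stable.
    chol = np.linalg.cholesky(cov + jitter * np.eye(n))
    eps = rng.standard_normal(n)
    return mean + chol @ eps


rng = np.random.default_rng(0)
m = np.array([0.0, 1.0, 2.0])
v = np.array([1.0, 0.5, 0.1])
print(sample_diag(m, v, rng))           # three independent draws
print(sample_full(m, np.diag(v), rng))  # same distribution, via the Cholesky route
```

With a diagonal covariance both routes target the same distribution. The Cholesky route is only needed once the covariance has off-diagonal terms.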

(
    Xnew: tf.Tensor,
    inducing_variable: InducingVariables,
    kernel: Kernel,
    f: tf.Tensor,
    *,
    full_cov: bool = False,
    full_output_cov: bool = False,
    q_sqrt: Optional[tf.Tensor] = None,
    white: bool = False,
    num_samples: Optional[int] = None,
)


# ... (tail of the decorator applied to this function; earlier lines not shown)
    "return[2]: [batch..., N, R, R] if (not full_cov) and full_output_cov",
)
def _sample_conditional(
    Xnew: tf.Tensor,
    inducing_variable: InducingVariables,
    kernel: Kernel,
    f: tf.Tensor,
    *,
    full_cov: bool = False,
    full_output_cov: bool = False,
    q_sqrt: Optional[tf.Tensor] = None,
    white: bool = False,
    num_samples: Optional[int] = None,
) -> SamplesMeanAndVariance:
    """
    `sample_conditional` will return a sample from the conditional distribution.
    In most cases this means calculating the conditional mean m and variance v and then
    returning m + sqrt(v) * eps, with eps ~ N(0, 1).
    However, for some combinations of Mok and Mof more efficient sampling routines exist.
    The dispatcher will make sure that we use the most efficient one.

    :return: samples, mean, cov
    """

    if full_cov and full_output_cov:
        msg = "The combination of both `full_cov` and `full_output_cov` is not permitted."
        raise NotImplementedError(msg)

    mean, cov = conditional(
        Xnew,
        inducing_variable,
        kernel,
        f,
        q_sqrt=q_sqrt,
        white=white,
        full_cov=full_cov,
        full_output_cov=full_output_cov,
    )
    if full_cov:
        # mean: [..., N, P]
        # cov: [..., P, N, N]
        mean_for_sample = tf.linalg.adjoint(mean)  # [..., P, N]
        samples = sample_mvn(
            mean_for_sample, cov, full_cov=True, num_samples=num_samples
        )  # [..., (S), P, N]
        samples = tf.linalg.adjoint(samples)  # [..., (S), N, P]
    else:
        # mean: [..., N, P]
        # cov: [..., N, P] or [..., N, P, P]
        samples = sample_mvn(
            mean, cov, full_cov=full_output_cov, num_samples=num_samples
        )  # [..., (S), N, P]

    return samples, mean, cov
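To show how the branches treat shapes, here is a NumPy sketch of the logic that runs after `conditional` returns. It assumes the moments are already computed and that there are no leading batch dimensions. `sample_from_moments` and `_mvn` are hypothetical helpers standing in for GPflow's `sample_mvn`. GPflow places the sample axis after any batch dimensions ([..., (S), N, P]); the sketch puts S first, which is equivalent when there are no batch dimensions.

```python
import numpy as np


def _mvn(mean, cov, num_samples, rng, jitter=1e-9):
    """Draw num_samples from N(mean, cov) per batch element.

    mean: [B, D]; cov: [B, D, D]; returns [S, B, D].
    """
    d = mean.shape[-1]
    chol = np.linalg.cholesky(cov + jitter * np.eye(d))     # [B, D, D]
    eps = rng.standard_normal((num_samples,) + mean.shape)  # [S, B, D]
    # Batched matrix-vector product L @ eps, broadcasting over S.
    return mean + (chol @ eps[..., None])[..., 0]


def sample_from_moments(mean, cov, full_cov=False, full_output_cov=False,
                        num_samples=None, rng=None):
    """Hypothetical mirror of the branch logic in _sample_conditional.

    mean: [N, P]
    cov:  [P, N, N] if full_cov, [N, P, P] if full_output_cov, else [N, P]
    returns samples shaped [(S), N, P] (no S axis when num_samples is None)
    """
    if full_cov and full_output_cov:
        raise NotImplementedError(
            "The combination of both `full_cov` and `full_output_cov` is not permitted."
        )
    rng = np.random.default_rng() if rng is None else rng
    s = 1 if num_samples is None else num_samples
    if full_cov:
        # One N x N covariance per output p: sample in [P, N] layout, then swap back.
        samples = _mvn(mean.T, cov, s, rng)     # [S, P, N]
        samples = np.swapaxes(samples, -1, -2)  # [S, N, P]
    elif full_output_cov:
        # One P x P covariance per input n: already in [N, P] layout.
        samples = _mvn(mean, cov, s, rng)       # [S, N, P]
    else:
        # Independent marginals: m + sqrt(v) * eps.
        eps = rng.standard_normal((s,) + mean.shape)
        samples = mean + np.sqrt(cov) * eps     # [S, N, P]
    return samples[0] if num_samples is None else samples


rng = np.random.default_rng(0)
x = np.linspace(0.0, 1.0, 4)                                # N = 4 inputs
K = np.exp(-0.5 * (x[:, None] - x[None, :]) ** 2 / 0.3**2)  # [N, N] RBF Gram matrix
mean = np.zeros((4, 2))                                     # P = 2 outputs
cov_full = np.stack([K, 2.0 * K])                           # [P, N, N]
print(sample_from_moments(mean, cov_full, full_cov=True, num_samples=3, rng=rng).shape)
# (3, 4, 2)
```

The transpose in the `full_cov` branch follows the same reasoning as the `tf.linalg.adjoint` calls in the source. Each output p carries its own N x N covariance, so the samples are drawn in [P, N] layout and then swapped back to [N, P].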

Callers

Nothing calls this function directly; it is reached through the `sample_conditional` dispatcher.

Calls

`conditional`, `sample_mvn`

Tested by

no test coverage detected
