
Function _sample_conditional

gpflow/conditionals/multioutput/sample_conditionals.py, lines 42–77 (github.com/GPflow/GPflow)

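`sample_conditional` returns a sample from the conditional distribution. In most cases this means computing the conditional mean m and variance v and returning m + sqrt(v) * eps, with eps ~ N(0, 1). For some combinations of multi-output kernel (Mok) and multi-output inducing-variable (Mof) types, however, more efficient sampling routines exist. This overload handles `SharedIndependentInducingVariables` with a `LinearCoregionalization` kernel: it samples the latent GPs and mixes the sample through `kernel.W`.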
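The generic case described above (draw eps ~ N(0, 1), return m + sqrt(v) * eps) can be sketched in plain Python for a Gaussian with independent marginals, i.e. the `full_cov=False` setting. `sample_diagonal_gaussian` is a hypothetical helper for illustration only; in the listing below the corresponding step is delegated to `sample_mvn` and runs in TensorFlow.

```python
import math
import random


def sample_diagonal_gaussian(mean, var, num_samples, rng):
    """Hypothetical helper: draw m + sqrt(v) * eps with eps ~ N(0, 1), per entry.

    mean, var: equal-length lists of marginal means and variances.
    Returns num_samples lists, each the same length as mean.
    """
    samples = []
    for _ in range(num_samples):
        eps = [rng.gauss(0.0, 1.0) for _ in mean]  # standard normal noise
        samples.append([m + math.sqrt(v) * e for m, v, e in zip(mean, var, eps)])
    return samples


rng = random.Random(0)
mean, var = [1.0, -2.0], [0.25, 4.0]
draws = sample_diagonal_gaussian(mean, var, num_samples=20000, rng=rng)

# The empirical moments should be close to the requested mean and variance.
for i in range(len(mean)):
    column = [d[i] for d in draws]
    mu_hat = sum(column) / len(column)
    var_hat = sum((x - mu_hat) ** 2 for x in column) / (len(column) - 1)
    print(f"dim {i}: mean ~ {mu_hat:.2f}, var ~ {var_hat:.2f}")
```

Scaling by sqrt(v) rather than v is what makes the variance of each draw equal v, since Var(sqrt(v) * eps) = v * Var(eps) = v.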

(
    Xnew: tf.Tensor,
    inducing_variable: SharedIndependentInducingVariables,
    kernel: LinearCoregionalization,
    f: tf.Tensor,
    *,
    full_cov: bool = False,
    full_output_cov: bool = False,
    q_sqrt: Optional[tf.Tensor] = None,
    white: bool = False,
    num_samples: Optional[int] = None,
)


```python
# (The listing opens mid-decorator; its last visible entry is "return[2]: [batch..., N, P]".)
# tf, Optional, the inducing-variable and kernel classes, conditional, sample_mvn,
# mix_latent_gp and SamplesMeanAndVariance come from module-level imports not shown here.
def _sample_conditional(
    Xnew: tf.Tensor,
    inducing_variable: SharedIndependentInducingVariables,
    kernel: LinearCoregionalization,
    f: tf.Tensor,
    *,
    full_cov: bool = False,
    full_output_cov: bool = False,
    q_sqrt: Optional[tf.Tensor] = None,
    white: bool = False,
    num_samples: Optional[int] = None,
) -> SamplesMeanAndVariance:
    """
    `sample_conditional` will return a sample from the conditional distribution.
    In most cases this means calculating the conditional mean m and variance v and then
    returning m + sqrt(v) * eps, with eps ~ N(0, 1).
    However, for some combinations of Mok and Mof, more efficient sampling routines exist.
    The dispatcher will make sure that we use the most efficient one.

    :return: samples, mean, cov
    """
    if full_cov:
        raise NotImplementedError("full_cov not yet implemented")
    if full_output_cov:
        raise NotImplementedError("full_output_cov not yet implemented")

    # Look up the conditional registered for separate independent kernels and use it
    # to get the marginal mean and variance of the L latent GPs.
    ind_conditional = conditional.dispatch_or_raise(
        object, SeparateIndependentInducingVariables, SeparateIndependent, object
    )
    g_mu, g_var = ind_conditional(
        Xnew, inducing_variable, kernel, f, white=white, q_sqrt=q_sqrt
    )  # [..., N, L], [..., N, L]
    # Sample the latent GPs, then map both the moments and the sample through W ([P, L]).
    g_sample = sample_mvn(g_mu, g_var, full_cov, num_samples=num_samples)  # [..., (S), N, L]
    f_mu, f_var = mix_latent_gp(kernel.W, g_mu, g_var, full_cov, full_output_cov)
    f_sample = tf.tensordot(g_sample, kernel.W, [[-1], [-1]])  # [..., (S), N, P]
    return f_sample, f_mu, f_var
```
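With a `LinearCoregionalization` kernel the P outputs are a linear mix f = W g of L independent latent GPs, so the function samples the latent GPs once and maps that sample through `kernel.W` with `tf.tensordot`, while `mix_latent_gp` maps the mean and variance. The plain-Python sketch below works for a single input point and only the diagonal case the listing supports (`full_cov=False`, `full_output_cov=False`), assuming W is stored as a P x L matrix. `mix_diagonal`, `mix_sample` and `empirical_cov` are hypothetical helpers that mirror the idea, not GPflow's code.

```python
import math
import random


def mix_diagonal(W, g_mu, g_var):
    """Map L independent latent marginals through W (P x L) for one input point.

    Returns per-output means and variances (diagonal case, no cross-output terms).
    """
    f_mu = [sum(w * m for w, m in zip(row, g_mu)) for row in W]
    f_var = [sum(w * w * v for w, v in zip(row, g_var)) for row in W]
    return f_mu, f_var


def mix_sample(W, g_sample):
    """Contract the latent axis of one sample with the last axis of W,
    analogous to tf.tensordot(g_sample, W, [[-1], [-1]]) for a single point."""
    return [sum(w * g for w, g in zip(row, g_sample)) for row in W]


def empirical_cov(draws, i, j):
    """Unbiased sample covariance between outputs i and j."""
    mi = sum(d[i] for d in draws) / len(draws)
    mj = sum(d[j] for d in draws) / len(draws)
    return sum((d[i] - mi) * (d[j] - mj) for d in draws) / (len(draws) - 1)


W = [[1.0, 0.5], [0.0, 2.0], [-1.0, 1.0]]  # P = 3 outputs, L = 2 latent GPs
g_mu, g_var = [0.3, -0.7], [0.5, 0.2]
f_mu, f_var = mix_diagonal(W, g_mu, g_var)

rng = random.Random(1)
f_draws = []
for _ in range(20000):
    # One joint draw of the independent latent GPs: g = g_mu + sqrt(g_var) * eps.
    g = [m + math.sqrt(v) * rng.gauss(0.0, 1.0) for m, v in zip(g_mu, g_var)]
    f_draws.append(mix_sample(W, g))  # mix the sample, not the marginals

# Cross-output covariance implied by the mixing: sum_l W[1][l] * W[2][l] * g_var[l].
implied_cov_12 = sum(a * b * v for a, b, v in zip(W[1], W[2], g_var))
print("f_mu:", f_mu, "f_var:", f_var)
print("empirical cov(1, 2):", round(empirical_cov(f_draws, 1, 2), 2), "implied:", implied_cov_12)
```

Because every mixed draw comes from one joint latent sample, the draws carry the full output covariance W diag(g_var) W^T (0.4 between the second and third outputs here), even though the returned variances are per output only. This illustrates why mixing a latent sample is preferable to sampling each output independently from its marginal mean and variance.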

Callers

No direct callers; the function is selected at runtime by the `sample_conditional` dispatcher.

Calls (3)

sample_mvn · function
mix_latent_gp · function
dispatch_or_raise · method

Tested by

no test coverage detected
