Function test_sample_conditional

github.com/GPflow/GPflow · tests/gpflow/conditionals/test_multioutput.py:264–320
(whiten: bool, full_cov: bool, full_output_cov: bool)
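The listed range starts at the `def` line, so any parametrization decorators are not shown here. Assume each of the three flags is parametrized over `True` and `False`. The signature suggests this, but this page does not confirm it. Under that assumption, the two guards in the body (the early `return` and the `pytest.skip` before Path 2) split the eight flag combinations as sketched below. `outcome` is a hypothetical helper that mirrors the body's control flow. It is plain Python, not GPflow code.

```python
import itertools


def outcome(whiten: bool, full_cov: bool, full_output_cov: bool) -> str:
    """Classify one flag combination the way the test body treats it (hypothetical helper)."""
    # `whiten` never changes control flow; it is only forwarded as `white=`.
    if full_cov and full_output_cov:
        return "early return"  # the body returns before sampling anything
    if full_output_cov:
        return "path 1, then skip"  # pytest.skip() runs after the InducingPoints path
    return "both paths compared"


# Assumed parametrization: every flag takes both boolean values.
grid = {flags: outcome(*flags) for flags in itertools.product([True, False], repeat=3)}

for (whiten, full_cov, full_output_cov), result in grid.items():
    print(f"whiten={whiten!s:5} full_cov={full_cov!s:5} full_output_cov={full_output_cov!s:5} -> {result}")
```

The early `return` makes the two `full_cov and full_output_cov` cases pass without exercising anything. By contrast, `pytest.skip` reports its two cases as skipped. Only four combinations actually compare the two paths.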

```python
def test_sample_conditional(whiten: bool, full_cov: bool, full_output_cov: bool) -> None:
    if full_cov and full_output_cov:
        return

    q_mu = tf.random.uniform((Data.M, Data.P), dtype=tf.float64)  # [M, P]
    q_sqrt = tf.convert_to_tensor(
        [np.tril(tf.random.uniform((Data.M, Data.M), dtype=tf.float64)) for _ in range(Data.P)]
    )  # [P, M, M]

    Z = Data.X[: Data.M, ...]  # [M, D]
    Xs: AnyNDArray = np.ones((Data.N, Data.D), dtype=float_type)

    inducing_variable = InducingPoints(Z)
    kernel = SquaredExponential()

    # Path 1: inducing inputs wrapped in an InducingPoints object
    value_f, mean_f, var_f = sample_conditional(
        Xs,
        inducing_variable,
        kernel,
        q_mu,
        q_sqrt=q_sqrt,
        white=whiten,
        full_cov=full_cov,
        full_output_cov=full_output_cov,
        num_samples=int(1e5),
    )
    # merge the first two axes (samples and test points) into one axis of draws
    value_f = value_f.numpy().reshape((-1,) + value_f.numpy().shape[2:])

    # Path 2: the same inducing inputs passed as a plain array Z
    if full_output_cov:
        pytest.skip(
            "sample_conditional with X instead of inducing_variable does not support full_output_cov"
        )

    value_x, mean_x, var_x = sample_conditional(
        Xs,
        Z,
        kernel,
        q_mu,
        q_sqrt=q_sqrt,
        white=whiten,
        full_cov=full_cov,
        full_output_cov=full_output_cov,
        num_samples=int(1e5),
    )
    value_x = value_x.numpy().reshape((-1,) + value_x.numpy().shape[2:])

    # the Monte Carlo mean and covariance of both sample sets must agree to one decimal
    np.testing.assert_array_almost_equal(
        np.mean(value_x, axis=0), np.mean(value_f, axis=0), decimal=1
    )
    np.testing.assert_array_almost_equal(
        np.cov(value_x, rowvar=False), np.cov(value_f, rowvar=False), decimal=1
    )
    # the analytic predictive mean and variance of both paths must match closely
    np.testing.assert_allclose(mean_x, mean_f)
    np.testing.assert_allclose(var_x, var_f)
```
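The sample comparison at the end of the test relies on two facts. First, `Xs = np.ones(...)` makes every one of the N test inputs identical, so every point shares the same marginal distribution. Second, with 1e5 draws the Monte Carlo error sits far below the `decimal=1` tolerance, which NumPy defines as `abs(desired - actual) < 1.5 * 10**-1` elementwise.

The NumPy sketch below reproduces that check on a Gaussian whose mean and covariance are known. It draws samples with a Cholesky factor. It assumes the samples are laid out as `[S, N, P]`, consistent with the test's reshape merging the first two axes. `sample_gaussian` and all numbers are illustrative assumptions, not GPflow code.

```python
import numpy as np


def sample_gaussian(mean, cov, num_samples, num_points, rng):
    """Draw [S, N, P] samples in which every one of the N points has the same
    P-dimensional marginal N(mean, cov), using the Cholesky factor of cov."""
    chol = np.linalg.cholesky(cov)  # [P, P], lower triangular, chol @ chol.T == cov
    eps = rng.standard_normal((num_samples, num_points, mean.shape[0]))  # [S, N, P]
    return mean + eps @ chol.T  # [S, N, P]


rng = np.random.default_rng(0)
mean = np.array([0.5, -0.2, 1.0])
cov = np.array([[1.0, 0.3, 0.0],
                [0.3, 0.5, 0.1],
                [0.0, 0.1, 0.8]])

# Two independent "paths" that should describe the same distribution.
value_a = sample_gaussian(mean, cov, num_samples=25_000, num_points=4, rng=rng)
value_b = sample_gaussian(mean, cov, num_samples=25_000, num_points=4, rng=rng)

# Same idiom as the test: merge the sample and point axes into one axis of draws.
value_a = value_a.reshape((-1,) + value_a.shape[2:])  # [S * N, P]
value_b = value_b.reshape((-1,) + value_b.shape[2:])  # [S * N, P]

# Moment comparison with the same tolerance the test uses.
np.testing.assert_array_almost_equal(np.mean(value_a, axis=0), np.mean(value_b, axis=0), decimal=1)
np.testing.assert_array_almost_equal(np.cov(value_a, rowvar=False), np.cov(value_b, rowvar=False), decimal=1)
```

With `full_cov=True` the N points within one draw are correlated. That does not change their common marginal, so pooling still estimates the same mean and covariance, only with fewer effectively independent draws.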

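`q_sqrt` is built as a stack of P lower-triangular `[M, M]` factors, one per output. In GPflow's variational parameterization, each factor is the Cholesky-style square root of the covariance of q(u) for that output (of the whitened variables when `white=True`). The covariance is therefore `q_sqrt[p] @ q_sqrt[p].T`.

The NumPy sketch below mirrors the test's construction with illustrative sizes. The `M`, `P` and seed values are assumptions; the test takes its sizes from `Data`. It checks that every factor is lower triangular and every implied covariance is symmetric positive semi-definite.

```python
import numpy as np

rng = np.random.default_rng(1)
M, P = 5, 3  # illustrative sizes; the test takes them from its Data class

# Same construction as the test: one lower-triangular [M, M] factor per output.
q_sqrt = np.stack([np.tril(rng.uniform(size=(M, M))) for _ in range(P)])  # [P, M, M]

# Covariance implied by each factor: S_p = L_p @ L_p.T, computed for all outputs at once.
S = q_sqrt @ np.swapaxes(q_sqrt, -1, -2)  # [P, M, M]

print(q_sqrt.shape, S.shape)
print(np.allclose(q_sqrt, np.tril(q_sqrt)))  # every factor is lower triangular
print(np.all(np.linalg.eigvalsh(S) >= -1e-12))  # every implied covariance is PSD
```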
Callers

Nothing calls this function directly; pytest collects and runs it.

Calls (2)

InducingPoints · class · 0.90
SquaredExponential · class · 0.90
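GPflow's `SquaredExponential` computes k(x, x') = σ² exp(-‖x − x'‖² / (2ℓ²)), and its variance and lengthscale both default to 1. That is how `SquaredExponential()` is constructed in the test. The NumPy sketch below is an independent reimplementation, not GPflow's code; `squared_exponential` is a hypothetical helper.

It shows what this kernel does to the test's inputs. Every row of `Xs` is identical, so the prior covariance between any two test points equals the variance. The N × N prior covariance is therefore rank one, and all N points share a single predictive marginal.

```python
import numpy as np


def squared_exponential(X, X2, variance=1.0, lengthscale=1.0):
    """k(x, x') = variance * exp(-||x - x'||^2 / (2 * lengthscale^2)) for every row pair."""
    sq_dist = np.sum((X[:, None, :] - X2[None, :, :]) ** 2, axis=-1)  # [N, N2]
    return variance * np.exp(-0.5 * sq_dist / lengthscale**2)


N, D = 4, 2  # illustrative sizes
Xs = np.ones((N, D))  # identical test inputs, as in the test
K_ss = squared_exponential(Xs, Xs)  # [N, N]

print(K_ss)  # every entry equals the variance (1.0 here)
print(np.linalg.matrix_rank(K_ss))  # a single shared direction: rank one
```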

Tested by

No test coverage detected (this function is itself a test).
