github.com/GPflow/GPflow

Function test_sample_mvn_shapes

tests/gpflow/conditionals/test_util.py:115–136
Signature: (leading_dims: Tuple[int, ...], n: int, d: int, num_samples: Optional[int], full_cov: bool) -> None
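Taken together, the five parameters determine every tensor shape the test builds and checks. The stdlib-only sketch below spells out that bookkeeping; the helper name mvn_shapes is hypothetical and not part of GPflow. It mirrors the test's own truthiness check on num_samples.

```python
from typing import Optional, Tuple

Shape = Tuple[int, ...]


def mvn_shapes(
    leading_dims: Shape, n: int, d: int, num_samples: Optional[int], full_cov: bool
) -> Tuple[Shape, Shape, Shape]:
    """Hypothetical helper: (mean, covariance, expected sample) shapes used by the test."""
    mean_shape = leading_dims + (n, d)
    # full_cov=True: one d x d covariance per point; otherwise d marginal variances.
    cov_shape = leading_dims + (n, d, d) if full_cov else leading_dims + (n, d)
    # A sample axis is inserted just before (n, d); like the test's "if num_samples:",
    # a falsy num_samples (None) means the result has no sample axis.
    sample_shape = leading_dims + (num_samples, n, d) if num_samples else mean_shape
    return mean_shape, cov_shape, sample_shape


print(mvn_shapes((2, 3), 4, 5, 7, True))
# → ((2, 3, 4, 5), (2, 3, 4, 5, 5), (2, 3, 7, 4, 5))
```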

Source (lines 113–136, from the content-addressed store, hash-verified):

```python
@pytest.mark.parametrize("num_samples", [None, 1, 5])
@pytest.mark.parametrize("full_cov", [True, False])
def test_sample_mvn_shapes(
    leading_dims: Tuple[int, ...], n: int, d: int, num_samples: Optional[int], full_cov: bool
) -> None:
    mean_shape = leading_dims + (n, d)
    means = tf.zeros(mean_shape, dtype=default_float())

    if full_cov:
        covariance_shape = leading_dims + (n, d, d)
        sqrt_cov = tf.random.normal(covariance_shape, dtype=default_float())
        covariances = tf.matmul(sqrt_cov, sqrt_cov, transpose_b=True)
    else:
        covariance_shape = leading_dims + (n, d)
        covariances = tf.random.normal(covariance_shape, dtype=default_float())

    samples = sample_mvn(means, covariances, full_cov, num_samples)

    if num_samples:
        expected_shape = leading_dims + (num_samples, n, d)
    else:
        expected_shape = leading_dims + (n, d)

    assert_equal(samples.shape, expected_shape)
```
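The test only pins down the shape contract of sample_mvn: means of shape [..., N, D], covariances of shape [..., N, D, D] (full) or [..., N, D] (diagonal), and samples of shape [..., S, N, D], with the sample axis dropped when num_samples is None. The following is a minimal NumPy sketch of a sampler satisfying that contract, assuming the standard reparameterization x = mean + L @ eps with L the Cholesky factor of the covariance (full case) and x = mean + sqrt(var) * eps (diagonal case). The name sample_mvn_sketch and its rng and jitter parameters are hypothetical; this is not GPflow's implementation.

```python
from typing import Optional

import numpy as np


def sample_mvn_sketch(
    mean: np.ndarray,
    cov: np.ndarray,
    full_cov: bool,
    num_samples: Optional[int] = None,
    rng: Optional[np.random.Generator] = None,
    jitter: float = 1e-6,
) -> np.ndarray:
    """Hypothetical NumPy analogue of the sample_mvn shape contract.

    mean: [..., N, D]; cov: [..., N, D, D] if full_cov else [..., N, D].
    Returns [..., S, N, D], or [..., N, D] when num_samples is None.
    """
    rng = np.random.default_rng() if rng is None else rng
    s = 1 if num_samples is None else num_samples
    leading = mean.shape[:-2]
    n, d = mean.shape[-2:]
    eps = rng.standard_normal(leading + (s, n, d))  # [..., S, N, D]

    if full_cov:
        # Cholesky factor of each D x D block; jitter guards against near-singular input.
        chol = np.linalg.cholesky(cov + jitter * np.eye(d))  # [..., N, D, D]
        # Broadcast chol over the sample axis and multiply every eps vector by it.
        correlated = (chol[..., None, :, :, :] @ eps[..., None])[..., 0]  # [..., S, N, D]
        samples = mean[..., None, :, :] + correlated
    else:
        # Independent outputs: scale standard normals by the marginal std. deviations.
        samples = mean[..., None, :, :] + np.sqrt(cov)[..., None, :, :] * eps

    # Drop the sample axis when no sample count was requested.
    return samples[..., 0, :, :] if num_samples is None else samples


rng = np.random.default_rng(0)
mean = np.zeros((2, 3, 4, 5))
a = rng.standard_normal((2, 3, 4, 5, 5))
cov_full = a @ np.swapaxes(a, -1, -2)  # A A^T, built the same way as in the test
print(sample_mvn_sketch(mean, cov_full, True, 7, rng).shape)  # → (2, 3, 7, 4, 5)
cov_diag = rng.uniform(0.5, 2.0, size=(2, 3, 4, 5))
print(sample_mvn_sketch(mean, cov_diag, False, None, rng).shape)  # → (2, 3, 4, 5)
```

Note that the test fills the diagonal-case covariances with tf.random.normal, so some "variances" are negative; a square-root-based sampler such as this sketch would emit NaN entries for them. Because the test asserts only on samples.shape, that does not affect the check.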

Callers

Nothing calls this directly; pytest collects and runs it as a test.

Calls (2)

default_float · Function · 0.90
sample_mvn · Function · 0.90

Tested by

No test coverage detected.

Used in the wild: real call sites across dependent graphs
