(whiten: bool, full_cov: bool, full_output_cov: bool)
| 262 | |
| 263 | |
def test_sample_conditional(whiten: bool, full_cov: bool, full_output_cov: bool) -> None:
    """
    Check that ``sample_conditional`` agrees between its two ways of specifying inducing inputs.

    Path 1 passes an ``InducingPoints`` object; path 2 passes the raw inducing inputs ``Z``.
    Both calls share the same kernel, ``q_mu`` and ``q_sqrt``. The test asserts that:

    * the returned means and variances match (``assert_allclose``); and
    * the Monte Carlo sample mean and sample covariance agree to one decimal place.

    :param whiten: forwarded to ``sample_conditional`` as ``white``.
    :param full_cov: forwarded to ``sample_conditional``.
    :param full_output_cov: forwarded to ``sample_conditional``. Path 2 does not support it,
        so for ``full_output_cov=True`` only path 1 is run (as a smoke test) before skipping.
    """
    if full_cov and full_output_cov:
        # A bare ``return`` here would report this parametrization as passing even though
        # nothing is checked; skipping makes the untested combination visible in the report.
        pytest.skip("full_cov and full_output_cov are not exercised together by this test")

    num_samples = int(1e5)

    q_mu = tf.random.uniform((Data.M, Data.P), dtype=tf.float64)  # [M, P]
    q_sqrt = tf.convert_to_tensor(
        [np.tril(tf.random.uniform((Data.M, Data.M), dtype=tf.float64)) for _ in range(Data.P)]
    )  # [P, M, M]

    Z = Data.X[: Data.M, ...]  # [M, D]
    # Every row of Xs is identical, so all N test points share one predictive distribution.
    # This is what makes it valid to pool the sample axes of the draws below.
    Xs: AnyNDArray = np.ones((Data.N, Data.D), dtype=float_type)

    inducing_variable = InducingPoints(Z)
    kernel = SquaredExponential()

    # Path 1: inducing inputs wrapped in an InducingPoints object.
    value_f, mean_f, var_f = sample_conditional(
        Xs,
        inducing_variable,
        kernel,
        q_mu,
        q_sqrt=q_sqrt,
        white=whiten,
        full_cov=full_cov,
        full_output_cov=full_output_cov,
        num_samples=num_samples,
    )
    # Merge the first two axes into one pooled sample axis. This presumably turns
    # [num_samples, N, ...] into [num_samples * N, ...]; confirm against sample_conditional.
    value_f = value_f.numpy()
    value_f = value_f.reshape((-1,) + value_f.shape[2:])

    # Path 2: raw inducing inputs Z. This path does not support full_output_cov. Path 1 above
    # has still run for that case, so it serves as a smoke test before the comparison is skipped.
    if full_output_cov:
        pytest.skip(
            "sample_conditional with X instead of inducing_variable does not support full_output_cov"
        )

    value_x, mean_x, var_x = sample_conditional(
        Xs,
        Z,
        kernel,
        q_mu,
        q_sqrt=q_sqrt,
        white=whiten,
        full_cov=full_cov,
        full_output_cov=full_output_cov,
        num_samples=num_samples,
    )
    value_x = value_x.numpy()
    value_x = value_x.reshape((-1,) + value_x.shape[2:])

    # The two paths draw independent Monte Carlo samples, so their empirical moments are
    # compared only loosely (one decimal place).
    np.testing.assert_array_almost_equal(
        np.mean(value_x, axis=0), np.mean(value_f, axis=0), decimal=1
    )
    np.testing.assert_array_almost_equal(
        np.cov(value_x, rowvar=False), np.cov(value_f, rowvar=False), decimal=1
    )
    # The analytic mean and variance involve no sampling, so they are compared with
    # assert_allclose (default tolerances) rather than to one decimal place.
    np.testing.assert_allclose(mean_x, mean_f)
    np.testing.assert_allclose(var_x, var_f)
| 321 |
nothing calls this directly
no test coverage detected
searching dependent graphs…