The covariance of q(x) can be represented either by Cholesky matrices or by diagonal matrices. This test checks that both representations give the same KL divergence.
(white: bool)
| 133 | |
@pytest.mark.parametrize("white", [True, False])
def test_diags(white: bool) -> None:
    """
    The covariance of q(x) can be Cholesky matrices or diagonal matrices.
    Here we make sure the behaviours overlap.

    A diagonal ``q_sqrt`` and the equivalent stack of diagonal Cholesky
    factors must produce the same KL divergence. This is checked both
    against the whitened prior and against the prior covariance ``Datum.K``.

    Args:
        white: when True, compare against the whitened prior (``K=None``);
            when False, pass ``Datum.K`` as the prior covariance.
    """
    # the chols are diagonal matrices, with the same entries as the diag representation.
    chol_from_diag = tf.stack(
        [tf.linalg.diag(Datum.sqrt_diag[:, i]) for i in range(Datum.N)]  # [N, M, M]
    )
    # gpflow's gauss_kl treats K=None as the whitened prior, so `white=True`
    # must map to None (the previous `Datum.K if white else None` had the
    # labels inverted). Compute K once so both calls share the same prior.
    K = None if white else Datum.K
    kl_diag = gauss_kl(Datum.mu, Datum.sqrt_diag, K)
    kl_dense = gauss_kl(Datum.mu, chol_from_diag, K)

    np.testing.assert_allclose(kl_diag, kl_dense)
| 148 | |
| 149 | |
| 150 | @pytest.mark.parametrize("diag", [True, False]) |
nothing calls this directly
no test coverage detected
searching dependent graphs…