(
leading_dims: Tuple[int, ...], n: int, d: int, num_samples: Optional[int], full_cov: bool
)
@pytest.mark.parametrize("num_samples", [None, 1, 5])
@pytest.mark.parametrize("full_cov", [True, False])
def test_sample_mvn_shapes(
    leading_dims: Tuple[int, ...], n: int, d: int, num_samples: Optional[int], full_cov: bool
) -> None:
    """Check that ``sample_mvn`` returns samples of the expected shape.

    ``leading_dims``, ``n`` and ``d`` are presumably supplied by further
    ``parametrize`` decorators above this test (TODO confirm).

    Expected output shape:
      * ``num_samples`` given -> ``leading_dims + (num_samples, n, d)``
      * ``num_samples`` is ``None`` -> ``leading_dims + (n, d)``
    """
    mean_shape = leading_dims + (n, d)
    means = tf.zeros(mean_shape, dtype=default_float())

    if full_cov:
        covariance_shape = leading_dims + (n, d, d)
        sqrt_cov = tf.random.normal(covariance_shape, dtype=default_float())
        # A @ A^T is symmetric positive semi-definite, i.e. a valid covariance.
        covariances = tf.matmul(sqrt_cov, sqrt_cov, transpose_b=True)
    else:
        covariance_shape = leading_dims + (n, d)
        # Per-dimension variances must be non-negative; raw normal draws can be
        # negative, so square them to obtain a valid diagonal covariance.
        raw = tf.random.normal(covariance_shape, dtype=default_float())
        covariances = tf.square(raw)

    samples = sample_mvn(means, covariances, full_cov, num_samples)

    # Compare against None explicitly so that num_samples=0 is not
    # conflated with "no sample dimension".
    if num_samples is not None:
        expected_shape = leading_dims + (num_samples, n, d)
    else:
        expected_shape = leading_dims + (n, d)

    assert_equal(samples.shape, expected_shape)
nothing calls this directly
no test coverage detected
searching dependent graphs…