```python
import numpy as np
import tensorflow as tf

from gpflow.base import AnyNDArray
from gpflow.config import default_float
from gpflow.kullback_leiblers import gauss_kl


def test_unknown_size_inputs() -> None:
    """
    Test for #725 and #734. When the shape of the Gaussian's mean had at least
    one unknown dimension, `gauss_kl` would fail. This happened because
    `tf.size` can only output types `tf.int32` or `tf.int64`.
    """
    mu: AnyNDArray = np.ones([1, 4], dtype=default_float())
    sqrt: AnyNDArray = np.ones([4, 1, 1], dtype=default_float())

    # Same inputs, once wrapped as tf.constant tensors and once as raw NumPy arrays.
    known_shape = gauss_kl(*map(tf.constant, [mu, sqrt]))
    unknown_shape = gauss_kl(mu, sqrt)

    # Both paths must produce the same KL divergence.
    np.testing.assert_allclose(known_shape, unknown_shape)
```
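The docstring's point about `tf.size` can be illustrated with a small sketch. The helper `num_elements_as_float` below is hypothetical and not part of GPflow, and this is not GPflow's actual fix. It only shows the general pattern, assuming TensorFlow 2.x: when a dimension is unknown at trace time, the static shape contains `None`, so the element count has to come from `tf.size`. That returns an integer tensor, which must be cast explicitly before it is combined with floating-point values.

```python
import tensorflow as tf


# Hypothetical helper for illustration only; not a GPflow function.
# The input signature leaves both dimensions unknown, mimicking a mean
# whose shape is not fully known when the function is traced.
@tf.function(input_signature=[tf.TensorSpec(shape=[None, None], dtype=tf.float64)])
def num_elements_as_float(x: tf.Tensor) -> tf.Tensor:
    # During tracing x.shape is (None, None), so no static element count
    # exists; tf.size computes it at run time as an integer tensor.
    n = tf.size(x)  # tf.int32 by default (tf.int64 via out_type)
    # Cast explicitly to the input's float dtype before any float arithmetic.
    return tf.cast(n, x.dtype)


count = num_elements_as_float(tf.ones([1, 4], dtype=tf.float64))
print(count.numpy())  # 4.0
```

Because the signature fixes only the rank and dtype, the same traced function accepts inputs of any 2-D shape without retracing.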