(D_out: int, *shape: int)
| 91 | |
| 92 | |
def gen_q_sqrt(D_out: int, *shape: int) -> tf.Tensor:
    """Generate a stack of random lower-triangular matrices as a tensor.

    One sample is drawn per output dimension from the module-level ``rng``
    via ``rng.randn(*shape)`` and passed through ``np.tril``, which zeroes
    the entries above the main diagonal.  The ``D_out`` results are then
    collected into a single array with a new leading axis.

    Args:
        D_out: Number of matrices to generate (size of the leading axis).
        *shape: Dimensions forwarded unchanged to ``rng.randn`` for each
            individual sample.

    Returns:
        A tensor built from the stacked samples, with dtype
        ``default_float()``.
    """
    lower_factors = []
    for _ in range(D_out):
        # Draw each sample separately so the draw order on ``rng`` is one
        # ``randn(*shape)`` call per output dimension.
        sample = rng.randn(*shape)
        lower_factors.append(np.tril(sample))

    stacked = np.array(lower_factors)
    return tf.convert_to_tensor(stacked, dtype=default_float())
| 98 | |
| 99 | |
| 100 | def mean_function_factory( |
no test coverage detected
searching dependent graphs…