(
kernel: kernels.Kernel,
batch_shape: Tuple[int, ...],
batch2_shape: Tuple[int, ...],
rng: np.random.Generator,
)
| 140 | ) |
| 141 | @check_shapes() |
| 142 | def test_broadcasting( |
| 143 | kernel: kernels.Kernel, |
| 144 | batch_shape: Tuple[int, ...], |
| 145 | batch2_shape: Tuple[int, ...], |
| 146 | rng: np.random.Generator, |
| 147 | ) -> None: |
| 148 | if isinstance(kernel, kernels.Coregion): |
| 149 | D = 1 |
| 150 | X: AnyNDArray = rng.choice(kernel.rank, batch_shape + (D,)) |
| 151 | X2: AnyNDArray = rng.choice(kernel.rank, batch2_shape + (D,)) |
| 152 | else: |
| 153 | if isinstance(kernel, kernels.ChangePoints): |
| 154 | D = 1 |
| 155 | elif isinstance(kernel, kernels.Convolutional): |
| 156 | D = int(np.prod(kernel.image_shape)) |
| 157 | else: |
| 158 | D = 4 |
| 159 | |
| 160 | X = rng.random(batch_shape + (D,)) |
| 161 | X2 = rng.random(batch2_shape + (D,)) |
| 162 | |
| 163 | rank = len(batch_shape) - 1 |
| 164 | rank2 = len(batch2_shape) - 1 |
| 165 | |
| 166 | if isinstance(kernel, kernels.MultioutputKernel): |
| 167 | mo_kernel: kernels.MultioutputKernel = kernel |
| 168 | |
| 169 | loop = cs( |
| 170 | unroll_batches( |
| 171 | lambda x: unroll_batches( |
| 172 | lambda x2: mo_kernel(x, x2, full_cov=True, full_output_cov=True), |
| 173 | X2, |
| 174 | 2, |
| 175 | ), |
| 176 | X, |
| 177 | 2, |
| 178 | ), |
| 179 | "[batch..., batch2..., N, P, N2, P]", |
| 180 | ) |
| 181 | loop = cs( |
| 182 | tf.transpose( |
| 183 | loop, |
| 184 | tf.concat( |
| 185 | [ |
| 186 | np.arange(rank), |
| 187 | [rank + rank2, rank + rank2 + 1], |
| 188 | np.arange(rank2) + rank, |
| 189 | [rank + rank2 + 2, rank + rank2 + 3], |
| 190 | ], |
| 191 | 0, |
| 192 | ), |
| 193 | ), |
| 194 | "[batch..., N, P, batch2..., N2, P]", |
| 195 | ) |
| 196 | native = cs( |
| 197 | mo_kernel(X, X2, full_cov=True, full_output_cov=True), |
| 198 | "[batch..., N, P, batch2..., N2, P]", |
| 199 | ) |
nothing calls this directly
no test coverage detected
searching dependent graphs…