MCPcopy
hub / github.com/GPflow/GPflow / test_broadcasting

Function test_broadcasting

tests/gpflow/kernels/test_broadcasting.py:142–311  ·  view source on GitHub ↗
(
    kernel: kernels.Kernel,
    batch_shape: Tuple[int, ...],
    batch2_shape: Tuple[int, ...],
    rng: np.random.Generator,
)

Source from the content-addressed store, hash-verified

140)
141@check_shapes()
142def test_broadcasting(
143 kernel: kernels.Kernel,
144 batch_shape: Tuple[int, ...],
145 batch2_shape: Tuple[int, ...],
146 rng: np.random.Generator,
147) -> None:
148 if isinstance(kernel, kernels.Coregion):
149 D = 1
150 X: AnyNDArray = rng.choice(kernel.rank, batch_shape + (D,))
151 X2: AnyNDArray = rng.choice(kernel.rank, batch2_shape + (D,))
152 else:
153 if isinstance(kernel, kernels.ChangePoints):
154 D = 1
155 elif isinstance(kernel, kernels.Convolutional):
156 D = int(np.prod(kernel.image_shape))
157 else:
158 D = 4
159
160 X = rng.random(batch_shape + (D,))
161 X2 = rng.random(batch2_shape + (D,))
162
163 rank = len(batch_shape) - 1
164 rank2 = len(batch2_shape) - 1
165
166 if isinstance(kernel, kernels.MultioutputKernel):
167 mo_kernel: kernels.MultioutputKernel = kernel
168
169 loop = cs(
170 unroll_batches(
171 lambda x: unroll_batches(
172 lambda x2: mo_kernel(x, x2, full_cov=True, full_output_cov=True),
173 X2,
174 2,
175 ),
176 X,
177 2,
178 ),
179 "[batch..., batch2..., N, P, N2, P]",
180 )
181 loop = cs(
182 tf.transpose(
183 loop,
184 tf.concat(
185 [
186 np.arange(rank),
187 [rank + rank2, rank + rank2 + 1],
188 np.arange(rank2) + rank,
189 [rank + rank2 + 2, rank + rank2 + 3],
190 ],
191 0,
192 ),
193 ),
194 "[batch..., N, P, batch2..., N2, P]",
195 )
196 native = cs(
197 mo_kernel(X, X2, full_cov=True, full_output_cov=True),
198 "[batch..., N, P, batch2..., N2, P]",
199 )

Callers

nothing calls this directly

Calls 2

unroll_batches · Function · 0.85
kernel · Function · 0.50

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…