()
| 152 | |
| 153 | |
def test_svgp_multiclass() -> None:
    """Smoke-test a multiclass SVGP end to end.

    Builds an SVGP with a ``MultiClass`` likelihood and one latent GP per
    class, wraps its ELBO in a ``tf.function`` whose input signature leaves
    the shapes of both data tensors unspecified, and then minimises the
    negated ELBO with the Scipy optimizer using ``maxiter=3``.

    There are no assertions: the test passes when nothing raises.
    """
    n_classes = 3
    svgp = gpflow.models.SVGP(
        gpflow.kernels.SquaredExponential(),
        gpflow.likelihoods.MultiClass(num_classes=n_classes),
        inducing_variable=Datum.X.copy(),
        num_latent_gps=n_classes,
    )
    # Freeze the inducing variable so it is excluded from the trainable set.
    gpflow.set_trainable(svgp.inducing_variable, False)

    scipy_opt = gpflow.optimizers.Scipy()

    # lambda because: https://github.com/GPflow/GPflow/issues/1929
    elbo_fn = lambda data: svgp.elbo(data)

    # test with explicitly unknown shapes:
    unknown_shape_spec = tf.TensorSpec(shape=None, dtype=default_float())
    data_signature = (unknown_shape_spec, unknown_shape_spec)
    traced_elbo = tf.function(elbo_fn, input_signature=[data_signature])

    @tf.function
    def negative_elbo() -> tf.Tensor:
        # Objective handed to the optimizer: the ELBO negated, on the fixture data.
        return -traced_elbo(Datum.cdata)

    negative_elbo()  # Trigger compilation.

    # simply test whether it runs without erroring...:
    scipy_opt.minimize(
        negative_elbo,
        variables=svgp.trainable_variables,
        options={"maxiter": 3},
        compile=True,
    )
nothing calls this directly
no test coverage detected
searching dependent graphs…