Repository: github.com/GPflow/GPflow

Function test_svgp_multiclass

Defined in tests/integration/test_dynamic_shapes.py, lines 154–188.
Signature: test_svgp_multiclass() -> None


```python
import gpflow
import tensorflow as tf
from gpflow.config import default_float

# Datum, which supplies the training inputs Datum.X and the (X, Y) pair
# Datum.cdata, is test data defined outside this excerpt (not shown).


def test_svgp_multiclass() -> None:
    num_classes = 3
    model = gpflow.models.SVGP(
        gpflow.kernels.SquaredExponential(),
        gpflow.likelihoods.MultiClass(num_classes=num_classes),
        inducing_variable=Datum.X.copy(),
        num_latent_gps=num_classes,
    )
    # Freeze the inducing inputs so the optimizer leaves them unchanged.
    gpflow.set_trainable(model.inducing_variable, False)

    # lambda because: https://github.com/GPflow/GPflow/issues/1929
    elbo_lambda = lambda data: model.elbo(data)

    # test with explicitly unknown shapes:
    tensor_spec = tf.TensorSpec(shape=None, dtype=default_float())
    elbo = tf.function(
        elbo_lambda,
        input_signature=[(tensor_spec, tensor_spec)],
    )

    # Loss for the optimizer: the negative ELBO on Datum.cdata.
    @tf.function
    def model_closure() -> tf.Tensor:
        return -elbo(Datum.cdata)

    model_closure()  # Trigger compilation.

    opt = gpflow.optimizers.Scipy()

    # simply test whether it runs without erroring...:
    opt.minimize(
        model_closure,
        variables=model.trainable_variables,
        options=dict(maxiter=3),
        compile=True,
    )
```
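The point of this test is the tf.TensorSpec(shape=None, ...) passed as input_signature: it asks TensorFlow to build a single graph that accepts tensors of any rank and size, so the compiled ELBO cannot depend on a fixed number of data points. The sketch below shows that effect in isolation. It assumes TensorFlow 2.x, where tf.function objects expose experimental_get_tracing_count(); squared_error and squared_error_no_signature are hypothetical stand-ins for model.elbo, not GPflow code, and float64 is used because that is GPflow's default_float() unless reconfigured.

```python
import tensorflow as tf

# shape=None declares a tensor of unknown rank and size, as in the test.
unknown = tf.TensorSpec(shape=None, dtype=tf.float64)


@tf.function(input_signature=[(unknown, unknown)])
def squared_error(data):
    # Hypothetical stand-in for model.elbo: takes an (X, Y) pair like Datum.cdata.
    x, y = data
    prediction = tf.reduce_sum(x, axis=-1, keepdims=True)
    return tf.reduce_sum((y - prediction) ** 2)


@tf.function
def squared_error_no_signature(data):
    # Same computation, but TensorFlow derives a signature from each call's shapes.
    x, y = data
    prediction = tf.reduce_sum(x, axis=-1, keepdims=True)
    return tf.reduce_sum((y - prediction) ** 2)


for n in (5, 17, 100):
    batch = (tf.ones((n, 2), dtype=tf.float64), tf.zeros((n, 1), dtype=tf.float64))
    squared_error(batch)
    squared_error_no_signature(batch)

# With shapes left unknown, one trace serves every batch size.
print(squared_error.experimental_get_tracing_count())
# Without a signature, each new input shape typically triggers a fresh trace.
print(squared_error_no_signature.experimental_get_tracing_count())
```

Retracing for every new shape is what a fixed-shape graph would hide; declaring the shapes as unknown forces the model code to work with dynamic shapes, which is what the test checks for SVGP with a MultiClass likelihood.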

Callers

nothing calls this directly; pytest discovers and runs it as a test

Calls (4)

minimize · method · 0.95
default_float · function · 0.90
model_closure · function · 0.85
elbo · method · 0.45
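The minimize call made by this test passes options=dict(maxiter=3), which keeps the optimization to a few iterations because the goal is only to confirm that the compiled closure runs end to end. As far as we understand gpflow.optimizers.Scipy, its minimize method forwards extra keyword arguments such as options to scipy.optimize.minimize, using method="L-BFGS-B" by default and a function that returns both the loss and its gradient. The SciPy-only sketch below reproduces that call pattern; the quadratic loss_and_grad and TARGET are hypothetical stand-ins for the model's negative ELBO.

```python
import numpy as np
from scipy.optimize import minimize

TARGET = np.array([1.0, -2.0, 3.0])


def loss_and_grad(x):
    # Hypothetical toy objective standing in for -ELBO. It returns
    # (loss, gradient) because jac=True tells SciPy to expect both.
    residual = x - TARGET
    return float(residual @ residual), 2.0 * residual


result = minimize(
    loss_and_grad,
    x0=np.zeros(3),
    jac=True,
    method="L-BFGS-B",
    options=dict(maxiter=3),  # same cap as the test: at most 3 iterations
)
print(result.nit, result.fun)
```

result.nit reports how many iterations were actually taken and cannot exceed maxiter; a problem this small may also stop earlier once SciPy's convergence tests are met.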

Tested by

no test coverage detected
