hub / github.com/GPflow/GPflow / test_vgp_multiclass

Function test_vgp_multiclass

tests/integration/test_dynamic_shapes.py:120–151


def test_vgp_multiclass() -> None:
    X = tf.Variable(
        tf.zeros((1, Datum.n_inputs), dtype=default_float()), shape=(None, None), trainable=False
    )
    Yc = tf.Variable(
        tf.zeros((1, Datum.n_outputs_c), dtype=default_float()), shape=(None, None), trainable=False
    )

    num_classes = 3
    model = gpflow.models.VGP(
        (X, Yc),
        gpflow.kernels.SquaredExponential(),
        gpflow.likelihoods.MultiClass(num_classes=num_classes),
        num_latent_gps=num_classes,
    )

    @tf.function
    def model_closure() -> tf.Tensor:
        return -model.elbo()

    model_closure()  # Trigger compilation.

    gpflow.models.vgp.update_vgp_data(model, (Datum.X, Datum.Yc))
    opt = gpflow.optimizers.Scipy()

    # simply test whether it runs without erroring...:
    opt.minimize(
        model_closure,
        variables=model.trainable_variables,
        options=dict(maxiter=3),
        compile=True,
    )
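The test above depends on a TensorFlow pattern that is easy to miss: a `tf.Variable` declared with `shape=(None, None)` can later hold data of a different size, so a `tf.function` compiled against the placeholder should not need to be retraced. The sketch below shows that mechanism with plain TensorFlow only, without GPflow. The names `sum_of_squares` and `trace_count` are hypothetical and appear nowhere in the GPflow test. The sketch does not reproduce `update_vgp_data`, which in the test is the call that swaps the real data into the model.

```python
import numpy as np
import tensorflow as tf

# Placeholder data: a single row for now, but both dimensions are declared
# unknown so that arrays of another size can be assigned later.
X = tf.Variable(tf.zeros((1, 2), dtype=tf.float64), shape=(None, None), trainable=False)

trace_count = 0  # Python side effect: runs only while TensorFlow traces the function.


@tf.function
def sum_of_squares() -> tf.Tensor:
    global trace_count
    trace_count += 1
    return tf.reduce_sum(tf.square(X))


first = sum_of_squares()  # Trigger compilation on the placeholder data.

# Swap in data with a different number of rows and columns.
X.assign(np.ones((5, 3)))
second = sum_of_squares()  # Reuses the graph traced on the first call.
```

Because the function takes no arguments and the variable's static shape is fully unknown, the second call should reuse the first trace, which is why `trace_count` is expected to stay at 1 after both calls.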

Callers

nothing calls this directly

Calls 3

minimize (Method) · 0.95
default_float (Function) · 0.90
model_closure (Function) · 0.85

Tested by

no test coverage detected
