()
| 118 | |
| 119 | |
def test_vgp_multiclass() -> None:
    """Check that a traced VGP multiclass ELBO still runs after a data swap.

    The model is first built on 1-row placeholder variables whose shapes are
    declared as ``(None, None)``. The negative-ELBO closure is traced once on
    that placeholder data. Only then is the real ``Datum`` data installed via
    ``gpflow.models.vgp.update_vgp_data``. A short Scipy run (``maxiter=3``)
    afterwards is a smoke test: it only checks that nothing raises, and fit
    quality is not asserted.
    """

    def placeholder(width: int) -> tf.Variable:
        # Non-trainable variable with an unspecified static shape; presumably
        # this lets update_vgp_data install data with a different row count
        # without invalidating the already-traced closure (confirm).
        return tf.Variable(
            tf.zeros((1, width), dtype=default_float()),
            shape=(None, None),
            trainable=False,
        )

    inputs = placeholder(Datum.n_inputs)
    labels = placeholder(Datum.n_outputs_c)

    n_classes = 3
    vgp = gpflow.models.VGP(
        (inputs, labels),
        gpflow.kernels.SquaredExponential(),
        gpflow.likelihoods.MultiClass(num_classes=n_classes),
        num_latent_gps=n_classes,
    )

    @tf.function
    def negative_elbo() -> tf.Tensor:
        return -vgp.elbo()

    # Trace (compile) the closure against the placeholder data first.
    negative_elbo()

    gpflow.models.vgp.update_vgp_data(vgp, (Datum.X, Datum.Yc))
    optimizer = gpflow.optimizers.Scipy()

    # Smoke test only: we just need optimisation to run without raising.
    optimizer.minimize(
        negative_elbo,
        variables=vgp.trainable_variables,
        options={"maxiter": 3},
        compile=True,
    )
| 152 | |
| 153 | |
| 154 | def test_svgp_multiclass() -> None: |
nothing calls this directly
no test coverage detected
searching dependent graphs…