Returns a sample from a D-dimensional Multivariate Normal distribution.

:return: sample from the MVN
(
mean: tf.Tensor, cov: tf.Tensor, full_cov: bool, num_samples: Optional[int] = None
)
| 177 | "return: [batch..., S, N, D] if num_samples is not None", |
| 178 | ) |
def sample_mvn(
    mean: tf.Tensor, cov: tf.Tensor, full_cov: bool, num_samples: Optional[int] = None
) -> tf.Tensor:
    """
    Returns a sample from a D-dimensional Multivariate Normal distribution.

    :param mean: Means, with shape [..., N, D].
    :param cov: If ``full_cov`` is False, per-entry variances with shape [..., N, D];
        every one of the N*D entries is sampled independently using ``sqrt(cov)`` as
        its standard deviation. If ``full_cov`` is True, one D x D covariance matrix
        per row, with shape [..., N, D, D]; ``default_jitter()`` is added to its
        diagonal before the Cholesky factorisation.
    :param full_cov: Selects which of the two ``cov`` layouts above is used.
    :param num_samples: Number of samples S to draw. If None, a single sample is
        drawn and the sample axis is squeezed out of the result.
    :return: sample from the MVN, with shape [..., S, N, D] if ``num_samples`` is not
        None, otherwise [..., N, D]. Noise is drawn in the dtype of ``mean``.
    """
    mean_shape = tf.shape(mean)
    S = num_samples if num_samples is not None else 1
    D = mean_shape[-1]
    # Draw the noise (and build the jitter matrix) in the caller's dtype rather than
    # the global default_float(): otherwise inputs in any other float type fail with a
    # dtype mismatch when added to eps / jittermat. For inputs that already use
    # default_float() this is identical to drawing in default_float().
    dtype = mean.dtype

    if not full_cov:
        # mean: [..., N, D] and cov [..., N, D]
        leading_dims = mean_shape[:-2]
        eps_shape = tf.concat([leading_dims, [S], mean_shape[-2:]], 0)
        eps = tf.random.normal(eps_shape, dtype=dtype)  # [..., S, N, D]
        # Insert a sample axis before N so mean and standard deviation broadcast over S.
        samples = mean[..., None, :, :] + tf.sqrt(cov)[..., None, :, :] * eps  # [..., S, N, D]

    else:
        # mean: [..., N, D] and cov [..., N, D, D]
        # Diagonal jitter added before the Cholesky factorisation, presumably to guard
        # against (near-)singular covariance matrices.
        jittermat = (
            tf.eye(D, batch_shape=mean_shape[:-1], dtype=dtype) * default_jitter()
        )  # [..., N, D, D]
        eps_shape = tf.concat([mean_shape, [S]], 0)
        eps = tf.random.normal(eps_shape, dtype=dtype)  # [..., N, D, S]
        chol = tf.linalg.cholesky(cov + jittermat)  # [..., N, D, D]
        samples = mean[..., None] + tf.linalg.matmul(chol, eps)  # [..., N, D, S]
        # Move the trailing sample axis S in front of N.
        samples = leading_transpose(samples, [..., -1, -3, -2])  # [..., S, N, D]

    if num_samples is None:
        return tf.squeeze(samples, axis=-3)  # [..., N, D]
    return samples  # [..., S, N, D]
| 212 | |
| 213 | |
| 214 | @check_shapes( |
searching dependent graphs…