hub / github.com/GPflow/GPflow / sample_mvn

Function sample_mvn

gpflow/conditionals/util.py:179–211

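Returns a sample from a D-dimensional multivariate normal (MVN) distribution with the given mean and covariance. With full_cov=False, cov holds per-dimension variances of shape [..., N, D]; with full_cov=True, it holds full covariance matrices of shape [..., N, D, D]. The result has shape [..., N, D] when num_samples is None and [..., S, N, D] otherwise.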

(
    mean: tf.Tensor, cov: tf.Tensor, full_cov: bool, num_samples: Optional[int] = None
)
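
As a reading aid, the two modes selected by full_cov can be written as the usual reparameterised draw from a Gaussian. This is a sketch of what the source computes, with the symbols \delta (standing for default_jitter()), L_n, and \varepsilon introduced here for illustration only:

```latex
% full_cov = False: cov holds variances \sigma^2_{nd}, each entry is sampled independently
x_{nd} = \mu_{nd} + \sqrt{\sigma^2_{nd}}\;\varepsilon_{nd},
\qquad \varepsilon_{nd} \sim \mathcal{N}(0, 1)

% full_cov = True: cov holds a D x D matrix \Sigma_n for each of the N points
L_n L_n^{\top} = \Sigma_n + \delta I_D,
\qquad
x_n = \mu_n + L_n \varepsilon_n,
\qquad \varepsilon_n \sim \mathcal{N}(0, I_D)

% so each sample has the intended covariance, up to the jitter term
\operatorname{Cov}[x_n] = L_n \operatorname{Cov}[\varepsilon_n] L_n^{\top} = \Sigma_n + \delta I_D
```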

Source

# Tail of the preceding @check_shapes decorator (truncated in this view):
#     "return: [batch..., S, N, D] if num_samples is not None",
# )
def sample_mvn(
    mean: tf.Tensor, cov: tf.Tensor, full_cov: bool, num_samples: Optional[int] = None
) -> tf.Tensor:
    """
    Returns a sample from a D-dimensional Multivariate Normal distribution.

    :return: sample from the MVN
    """
    mean_shape = tf.shape(mean)
    S = num_samples if num_samples is not None else 1
    D = mean_shape[-1]
    leading_dims = mean_shape[:-2]

    if not full_cov:
        # mean: [..., N, D] and cov [..., N, D]
        eps_shape = tf.concat([leading_dims, [S], mean_shape[-2:]], 0)
        eps = tf.random.normal(eps_shape, dtype=default_float())  # [..., S, N, D]
        samples = mean[..., None, :, :] + tf.sqrt(cov)[..., None, :, :] * eps  # [..., S, N, D]

    else:
        # mean: [..., N, D] and cov [..., N, D, D]
        jittermat = (
            tf.eye(D, batch_shape=mean_shape[:-1], dtype=default_float()) * default_jitter()
        )  # [..., N, D, D]
        eps_shape = tf.concat([mean_shape, [S]], 0)
        eps = tf.random.normal(eps_shape, dtype=default_float())  # [..., N, D, S]
        chol = tf.linalg.cholesky(cov + jittermat)  # [..., N, D, D]
        samples = mean[..., None] + tf.linalg.matmul(chol, eps)  # [..., N, D, S]
        samples = leading_transpose(samples, [..., -1, -3, -2])  # [..., S, N, D]

    if num_samples is None:
        return tf.squeeze(samples, axis=-3)  # [..., N, D]
    return samples  # [..., S, N, D]


# @check_shapes(   (start of the next definition, truncated in this view)
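
The two branches above can be mirrored in a small standard-library sketch, which may help when checking the shapes and the role of the jitter by hand. This is not GPflow's implementation: it handles a single [N, D] mean with no leading batch dimensions, and the names sample_mvn_sketch, cholesky, jitter, and rng are hypothetical. The value 1e-6 is only an assumed stand-in for default_jitter().

```python
import math
import random
from typing import List, Optional

Matrix = List[List[float]]


def cholesky(a: Matrix) -> Matrix:
    """Return lower-triangular L with L @ L.T == a (a must be positive definite)."""
    d = len(a)
    L = [[0.0] * d for _ in range(d)]
    for i in range(d):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                L[i][j] = math.sqrt(a[i][i] - s)
            else:
                L[i][j] = (a[i][j] - s) / L[j][j]
    return L


def sample_mvn_sketch(
    mean: Matrix,
    cov: list,
    full_cov: bool,
    num_samples: Optional[int] = None,
    jitter: float = 1e-6,  # assumed placeholder for default_jitter()
    rng: Optional[random.Random] = None,
):
    """mean is [N][D]; cov is [N][D] variances, or [N][D][D] matrices if full_cov."""
    rng = rng or random.Random()
    S = 1 if num_samples is None else num_samples

    chols = None
    if full_cov:
        # Add jitter to the diagonal before factorising, as the source does.
        chols = [
            cholesky([[c[i][j] + (jitter if i == j else 0.0) for j in range(len(c))]
                      for i in range(len(c))])
            for c in cov
        ]

    samples = []  # [S][N][D]
    for _ in range(S):
        draw = []
        for n, mu in enumerate(mean):
            D = len(mu)
            eps = [rng.gauss(0.0, 1.0) for _ in range(D)]
            if not full_cov:
                # Independent dimensions: scale each eps by its standard deviation.
                row = [mu[d] + math.sqrt(cov[n][d]) * eps[d] for d in range(D)]
            else:
                # Correlated dimensions: x = mu + L @ eps.
                L = chols[n]
                row = [mu[i] + sum(L[i][k] * eps[k] for k in range(i + 1)) for i in range(D)]
            draw.append(row)
        samples.append(draw)

    # Mirror the source: drop the sample axis when num_samples is None.
    return samples[0] if num_samples is None else samples
```

Calling sample_mvn_sketch(mean, cov, full_cov=False, num_samples=4) on an [N][D] mean returns a nested list indexed as [S][N][D], matching the [..., S, N, D] layout in the source comments. Omitting num_samples returns [N][D].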

Callers (5)

test_sample_mvn_shapes · Function · 0.90
test_sample_mvn · Function · 0.90
_sample_conditional · Function · 0.85
_sample_conditional · Function · 0.85
predict_f_samples · Method · 0.85

Calls (4)

default_float · Function · 0.85
default_jitter · Function · 0.85
leading_transpose · Function · 0.85
shape · Method · 0.45

Tested by (2)

test_sample_mvn_shapes · Function · 0.72
test_sample_mvn · Function · 0.72
