Generate random samples from the fitted Gaussian distribution. Parameters: n_samples : int, default=1 — Number of samples to generate. Returns: X : array, shape (n_samples, n_features) — Randomly generated sample; y : array, shape (n_samples,) — Component labels.
(self, n_samples=1)
| 432 | return xp.exp(log_resp) |
| 433 | |
| 434 | def sample(self, n_samples=1): |
| 435 | """Generate random samples from the fitted Gaussian distribution. |
| 436 | |
| 437 | Parameters |
| 438 | ---------- |
| 439 | n_samples : int, default=1 |
| 440 | Number of samples to generate. |
| 441 | |
| 442 | Returns |
| 443 | ------- |
| 444 | X : array, shape (n_samples, n_features) |
| 445 | Randomly generated sample. |
| 446 | |
| 447 | y : array, shape (nsamples,) |
| 448 | Component labels. |
| 449 | """ |
| 450 | check_is_fitted(self) |
| 451 | xp, _, device_ = get_namespace_and_device(self.means_) |
| 452 | |
| 453 | if n_samples < 1: |
| 454 | raise ValueError( |
| 455 | "Invalid value for 'n_samples': %d . The sampling requires at " |
| 456 | "least one sample." % (self.n_components) |
| 457 | ) |
| 458 | |
| 459 | _, n_features = self.means_.shape |
| 460 | rng = check_random_state(self.random_state) |
| 461 | n_samples_comp = rng.multinomial( |
| 462 | n_samples, move_to(self.weights_, xp=np, device="cpu") |
| 463 | ) |
| 464 | |
| 465 | if self.covariance_type == "full": |
| 466 | X = np.vstack( |
| 467 | [ |
| 468 | rng.multivariate_normal(mean, covariance, int(sample)) |
| 469 | for (mean, covariance, sample) in zip( |
| 470 | move_to(self.means_, xp=np, device="cpu"), |
| 471 | move_to(self.covariances_, xp=np, device="cpu"), |
| 472 | n_samples_comp, |
| 473 | ) |
| 474 | ] |
| 475 | ) |
| 476 | elif self.covariance_type == "tied": |
| 477 | X = np.vstack( |
| 478 | [ |
| 479 | rng.multivariate_normal( |
| 480 | mean, |
| 481 | move_to(self.covariances_, xp=np, device="cpu"), |
| 482 | int(sample), |
| 483 | ) |
| 484 | for (mean, sample) in zip( |
| 485 | move_to(self.means_, xp=np, device="cpu"), n_samples_comp |
| 486 | ) |
| 487 | ] |
| 488 | ) |
| 489 | else: |
| 490 | X = np.vstack( |
| 491 | [ |