| 344 | |
| 345 | |
| 346 | class DiscreteSampler: |
| 347 | def __init__(self, probs, log=False, with_replacement=True): |
| 348 | """ |
| 349 | Sample from an arbitrary multinomial PMF over the first `N` nonnegative |
| 350 | integers using Vose's algorithm for the alias method. |
| 351 | |
| 352 | Notes |
| 353 | ----- |
| 354 | Vose's algorithm takes `O(n)` time to initialize, requires `O(n)` memory, |
| 355 | and generates samples in constant time. |
| 356 | |
| 357 | References |
| 358 | ---------- |
| 359 | .. [1] Walker, A. J. (1977) "An efficient method for generating discrete |
| 360 | random variables with general distributions". *ACM Transactions on |
| 361 | Mathematical Software, 3(3)*, 253-256. |
| 362 | |
| 363 | .. [2] Vose, M. D. (1991) "A linear algorithm for generating random numbers |
| 364 | with a given distribution". *IEEE Trans. Softw. Eng., 17(9)*, 972-975. |
| 365 | |
| 366 | .. [3] Schwarz, K (2011) "Darts, dice, and coins: sampling from a discrete |
| 367 | distribution". http://www.keithschwarz.com/darts-dice-coins/ |
| 368 | |
| 369 | Parameters |
| 370 | ---------- |
| 371 | probs : :py:class:`ndarray <numpy.ndarray>` of shape `(N,)` |
| 372 | The probabilities of the `N` outcomes in the sample space. |
| 373 | `probs[i]` is the probability of outcome `i`. |
| 374 | log : bool |
| 375 | Whether the probabilities in `probs` are in logspace. Default is |
| 376 | False. |
| 377 | with_replacement : bool |
| 378 | Whether to generate samples with or without replacement. Default is |
| 379 | True. |
| 380 | """ |
| 381 | if not isinstance(probs, np.ndarray): |
| 382 | probs = np.array(probs) |
| 383 | |
| 384 | self.log = log |
| 385 | self.N = len(probs) |
| 386 | self.probs = probs |
| 387 | self.with_replacement = with_replacement |
| 388 | |
| 389 | alias = np.zeros(self.N) |
| 390 | prob = np.zeros(self.N) |
| 391 | scaled_probs = self.probs + np.log(self.N) if log else self.probs * self.N |
| 392 | |
| 393 | selector = scaled_probs < 0 if log else scaled_probs < 1 |
| 394 | small, large = np.where(selector)[0].tolist(), np.where(~selector)[0].tolist() |
| 395 | |
| 396 | while len(small) and len(large): |
| 397 | l, g = small.pop(), large.pop() |
| 398 | |
| 399 | alias[l] = g |
| 400 | prob[l] = scaled_probs[l] |
| 401 | |
| 402 | if log: |
| 403 | pg = np.log(np.exp(scaled_probs[g]) + np.exp(scaled_probs[l]) - 1) |
no outgoing calls