
Class DiscreteSampler

numpy_ml/utils/data_structures.py:346–470 (github.com/ddbourgin/numpy-ml)


```python
import numpy as np  # module-level import assumed by the listing


class DiscreteSampler:
    def __init__(self, probs, log=False, with_replacement=True):
        """
        Sample from an arbitrary multinomial PMF over the first `N` nonnegative
        integers using Vose's algorithm for the alias method.

        Notes
        -----
        Vose's algorithm takes `O(n)` time to initialize, requires `O(n)` memory,
        and generates samples in constant time.

        References
        ----------
        .. [1] Walker, A. J. (1977) "An efficient method for generating discrete
           random variables with general distributions". *ACM Transactions on
           Mathematical Software, 3(3)*, 253-256.

        .. [2] Vose, M. D. (1991) "A linear algorithm for generating random numbers
           with a given distribution". *IEEE Trans. Softw. Eng., 17(9)*, 972-974.

        .. [3] Schwarz, K. (2011) "Darts, dice, and coins: sampling from a discrete
           distribution". http://www.keithschwarz.com/darts-dice-coins/

        Parameters
        ----------
        probs : :py:class:`ndarray <numpy.ndarray>` of shape `(N,)`
            A list of probabilities of the `N` outcomes in the sample space.
            `probs[i]` returns the probability of outcome `i`.
        log : bool
            Whether the probabilities in `probs` are in logspace. Default is
            False.
        with_replacement : bool
            Whether to generate samples with or without replacement. Default is
            True.
        """
        if not isinstance(probs, np.ndarray):
            probs = np.array(probs)

        self.log = log
        self.N = len(probs)
        self.probs = probs
        self.with_replacement = with_replacement

        alias = np.zeros(self.N)
        prob = np.zeros(self.N)
        scaled_probs = self.probs + np.log(self.N) if log else self.probs * self.N

        selector = scaled_probs < 0 if log else scaled_probs < 1
        small, large = np.where(selector)[0].tolist(), np.where(~selector)[0].tolist()

        while len(small) and len(large):
            l, g = small.pop(), large.pop()

            alias[l] = g
            prob[l] = scaled_probs[l]

            if log:
                pg = np.log(np.exp(scaled_probs[g]) + np.exp(scaled_probs[l]) - 1)
                # ... (listing truncated here; the source continues to line 470)
```
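The listing stops partway through table construction, so the step that turns `prob` and `alias` into samples is not shown. As a rough guide to how Vose's method [2] is usually completed, here is a minimal standalone sketch in linear space. The function names `build_alias_table`, `alias_sample` and `alias_sample_without_replacement` are hypothetical and not part of numpy-ml; the sketch uses NumPy's `Generator` API, assumes `probs` already sums to 1, and the repeat-rejection approach to sampling without replacement is an assumption, not a description of what `DiscreteSampler` does internally.

```python
import numpy as np


def build_alias_table(probs):
    """Hypothetical helper: build Vose alias tables for a linear-space PMF.

    Assumes `probs` is 1-D, nonnegative and sums to 1. Returns (prob, alias):
    prob[i] is the chance of keeping column i, alias[i] the outcome used
    otherwise.
    """
    probs = np.asarray(probs, dtype=float)
    n = len(probs)
    scaled = probs * n                 # average column height becomes 1
    prob = np.ones(n)                  # leftover columns end up full
    alias = np.arange(n)               # ...and alias to themselves

    small = [i for i in range(n) if scaled[i] < 1]
    large = [i for i in range(n) if scaled[i] >= 1]

    while small and large:
        l, g = small.pop(), large.pop()
        prob[l] = scaled[l]            # column l keeps its own mass
        alias[l] = g                   # and is topped up by outcome g
        scaled[g] += scaled[l] - 1     # g donated 1 - scaled[l]
        (small if scaled[g] < 1 else large).append(g)

    return prob, alias


def alias_sample(prob, alias, n_samples, rng):
    """Hypothetical helper: draw `n_samples` outcomes, O(1) work per draw."""
    cols = rng.integers(0, len(prob), size=n_samples)  # uniform column
    keep = rng.random(n_samples) < prob[cols]          # biased coin flip
    return np.where(keep, cols, alias[cols])


def alias_sample_without_replacement(prob, alias, n_samples, rng):
    """Hypothetical helper: draw distinct outcomes by redrawing repeats.

    Rejecting a repeat samples from the PMF renormalised over the outcomes
    not yet chosen. Assumes at least `n_samples` outcomes have nonzero
    probability; otherwise the loop would never finish.
    """
    if n_samples > len(prob):
        raise ValueError("n_samples exceeds the number of outcomes")
    chosen, seen = [], set()
    while len(chosen) < n_samples:
        x = int(alias_sample(prob, alias, 1, rng)[0])
        if x not in seen:
            seen.add(x)
            chosen.append(x)
    return np.array(chosen, dtype=int)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    prob, alias = build_alias_table([0.5, 0.3, 0.2])
    draws = alias_sample(prob, alias, 100_000, rng)
    print(np.bincount(draws, minlength=3) / len(draws))
    print(alias_sample_without_replacement(prob, alias, 2, rng))
```

Building the table is a single O(N) pass, and each draw needs one uniform column index plus one biased coin flip, which is where the constant-time sampling in the Notes comes from. For `[0.5, 0.3, 0.2]` the table is `prob = [1, 0.9, 0.6]` with columns 1 and 2 aliasing to outcome 0; averaging the kept and donated mass over the three columns gives back the original PMF.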

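With `log=True`, the listing keeps the scaled probabilities as logarithms: it adds `np.log(self.N)` instead of multiplying by `N`, compares against 0 instead of 1, and forms the donor's leftover mass as `log(exp(s_g) + exp(s_l) - 1)`, the log of the linear update `p_g + p_l - 1`. A minimal standalone check, assuming the same small/large bookkeeping as the linear case (the function name `build_log_alias_table` is hypothetical), is to build the table in log space and confirm that exponentiating it reproduces the linear-space table. Because the update passes through `np.exp`, it assumes the scaled values stay representable after exponentiation.

```python
import numpy as np


def build_log_alias_table(log_probs):
    """Hypothetical log-space Vose table builder mirroring the listing's log branch.

    Returns (log_prob, alias): log_prob[i] is the log-probability of keeping
    column i, alias[i] is the outcome used otherwise.
    """
    log_probs = np.asarray(log_probs, dtype=float)
    n = len(log_probs)
    scaled = log_probs + np.log(n)  # log(p_i * n)
    log_prob = np.zeros(n)          # log(1) = 0: leftover columns are full
    alias = np.arange(n)            # leftover columns alias to themselves

    # log(p_i * n) < 0 exactly when p_i * n < 1
    small = [i for i in range(n) if scaled[i] < 0]
    large = [i for i in range(n) if scaled[i] >= 0]

    while small and large:
        l, g = small.pop(), large.pop()
        log_prob[l] = scaled[l]
        alias[l] = g
        # Same update as the listing: log(exp(s_g) + exp(s_l) - 1)
        scaled[g] = np.log(np.exp(scaled[g]) + np.exp(scaled[l]) - 1)
        (small if scaled[g] < 0 else large).append(g)

    return log_prob, alias


if __name__ == "__main__":
    log_prob, alias = build_log_alias_table(np.log([0.5, 0.3, 0.2]))
    print(np.exp(log_prob), alias)
```

Up to floating-point rounding, `np.exp(log_prob)` for `[0.5, 0.3, 0.2]` comes out as `[1, 0.9, 0.6]` with every column aliasing to outcome 0, the same table the linear-space branch produces.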