Function test_choice(generator_class)

github.com/dask/dask · dask/array/tests/test_random.py:306–372

```python
import numpy as np
import pytest

import dask.array as da


# `generator_class` is provided by pytest (set up elsewhere in test_random.py);
# it is either da.random.RandomState or da.random.default_rng.
def test_choice(generator_class):
    # NumPy counterpart of each dask generator. Its output dtype is the
    # reference, because the integer dtype returned by choice() may differ
    # between generators and platforms.
    np_generator = {
        da.random.RandomState: np.random.RandomState,
        da.random.default_rng: np.random.default_rng,
    }
    np_dtype = np_generator[generator_class]().choice(1, size=()).dtype
    # Integer `a`: sample from range(3) into a (10, 3) array with chunks of 4.
    size = (10, 3)
    chunks = 4
    x = generator_class().choice(3, size=size, chunks=chunks)
    assert x.dtype == np_dtype
    assert x.shape == size
    res = x.compute()
    assert res.dtype == np_dtype
    assert res.shape == size

    # Array-like `a` (list, NumPy array, dask array): the output dtype
    # follows `a`, and every sample is drawn from `a`.
    py_a = [1, 3, 5, 7, 9]
    np_a = np.array(py_a, dtype="f8")
    da_a = da.from_array(np_a, chunks=2)

    for a in [py_a, np_a, da_a]:
        x = generator_class().choice(a, size=size, chunks=chunks)
        res = x.compute()
        expected_dtype = np.asarray(a).dtype
        assert x.dtype == expected_dtype
        assert res.dtype == expected_dtype
        assert set(np.unique(res)).issubset(np_a)

    # Weighted sampling: `a` and `p` may each be a NumPy or a dask array.
    # p[0] == 0, so the first element of `a` is never drawn.
    np_p = np.array([0, 0.2, 0.2, 0.3, 0.3])
    da_p = da.from_array(np_p, chunks=2)

    for a, p in [(da_a, np_p), (np_a, da_p)]:
        x = generator_class().choice(a, size=size, chunks=chunks, p=p)
        res = x.compute()
        assert x.dtype == np_a.dtype
        assert res.dtype == np_a.dtype
        assert set(np.unique(res)).issubset(np_a[1:])

    # Integer `a` with `p`: recompute the NumPy reference dtype, which may
    # differ from the unweighted case.
    np_dtype = (
        np_generator[generator_class]().choice(1, size=(), p=np.array([1])).dtype
    )
    x = generator_class().choice(5, size=size, chunks=chunks, p=np_p)
    res = x.compute()
    assert x.dtype == np_dtype
    assert res.dtype == np_dtype

    # Invalid (a, p) combinations must raise ValueError.
    errs = [
        (-1, None),  # negative a
        (np_a[:, None], None),  # a must be 1D
        (np_a, np_p[:, None]),  # p must be 1D
        (np_a, np_p[:-2]),  # a and p must match
        (3, np_p),  # a and p must match
        (4, [0.2, 0.2, 0.3]),  # p must sum to 1
    ]

    for a, p in errs:
        with pytest.raises(ValueError):
            generator_class().choice(a, size=size, chunks=chunks, p=p)

    # Sampling without replacement is expected to raise NotImplementedError here.
    with pytest.raises(NotImplementedError):
        generator_class().choice(da_a, size=size, chunks=chunks, replace=False)
```
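The `errs` cases amount to a short list of argument checks on `a` and `p`. A minimal NumPy-only sketch of those checks follows. `check_choice_args` is a hypothetical helper written for illustration, not dask's own validation code; it covers only the cases this test exercises, and because it converts inputs with `np.asarray` it is not meant for lazy dask arrays.

```python
import numpy as np


def check_choice_args(a, p=None):
    """Hypothetical validator mirroring the ValueError cases in test_choice.

    Not dask's implementation. Returns (population size, p as a float array or None).
    """
    if isinstance(a, (int, np.integer)):
        # Integer `a` means "sample from range(a)"; it cannot be negative.
        if a < 0:
            raise ValueError("a must be non-negative")
        len_a = int(a)
    else:
        a = np.asarray(a)
        if a.ndim != 1:
            raise ValueError("a must be one dimensional")
        len_a = a.shape[0]

    if p is not None:
        p = np.asarray(p, dtype="f8")
        if p.ndim != 1:
            raise ValueError("p must be one dimensional")
        if p.shape[0] != len_a:
            raise ValueError("a and p must have the same size")
        # Allow for floating-point rounding in the sum.
        if not np.isclose(p.sum(), 1.0):
            raise ValueError("probabilities do not sum to 1")
    return len_a, p
```

In this sketch `(4, [0.2, 0.2, 0.3])` is rejected by the size check before the sum check is reached, since its lengths also differ; the test only requires that some `ValueError` is raised.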

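The `size`/`chunks` assertions reflect blockwise generation: the output is split into blocks of at most `chunks` elements per axis, and each block is sampled on its own. The sketch below imitates that eagerly with NumPy only, giving each block its own `np.random.Generator` seeded from a spawned `np.random.SeedSequence`. `split_dim` and `chunked_choice` are hypothetical names, and the seeding scheme is an assumption for illustration; it is not how dask builds or seeds its task graph, and it does not accept dask arrays for `a` or `p`.

```python
import math

import numpy as np


def split_dim(n, c):
    """Hypothetical helper: split length n into chunk lengths of at most c."""
    full, rest = divmod(n, c)
    return (c,) * full + ((rest,) if rest else ())


def chunked_choice(a, size, chunks, p=None, seed=0):
    """Hypothetical eager sketch of blockwise sampling (not dask's implementation).

    Each output block is drawn independently by NumPy, and the blocks are then
    stitched together with np.block.
    """
    block_dims = [split_dim(n, chunks) for n in size]
    n_blocks = math.prod(len(d) for d in block_dims)
    # One child seed per block (assumed scheme): blocks are independent but reproducible.
    seeds = iter(np.random.SeedSequence(seed).spawn(n_blocks))

    def build(index):
        if len(index) == len(block_dims):
            # Leaf: draw one block whose shape comes from the chunk grid.
            shape = tuple(d[i] for d, i in zip(block_dims, index))
            rng = np.random.default_rng(next(seeds))
            return rng.choice(a, size=shape, p=p)
        # Otherwise nest one list level per axis, as np.block expects.
        axis = len(index)
        return [build(index + (i,)) for i in range(len(block_dims[axis]))]

    return np.block(build(()))
```

With `size=(10, 3)` and `chunks=4`, `split_dim` yields row blocks `(4, 4, 2)` and a single column block `(3,)`, so three blocks are drawn and stacked back into a `(10, 3)` array whose dtype follows `a`. Because `np_p` gives the first element zero weight, its value never appears in the result, matching the `np_a[1:]` assertion.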
Callers

nothing calls this directly

Calls (5)

generator_class · Function · 0.85
set · Class · 0.85
choice · Method · 0.45
compute · Method · 0.45
unique · Method · 0.45

Tested by

no test coverage detected
