(generator_class)
| 304 | |
| 305 | |
def test_choice(generator_class):
    """Check ``choice`` on the dask random generator ``generator_class``.

    Covers:

    * dtype and shape agreement with the matching NumPy generator when ``a``
      is an integer, with and without ``p``;
    * a 1-D population ``a`` given as a list, a NumPy array or a dask array
      (samples keep the population's dtype and only contain its values);
    * probabilities ``p`` given as a NumPy or a dask array, mixed with either
      kind of ``a``;
    * ``ValueError`` for invalid ``a``/``p`` combinations, and
      ``NotImplementedError`` for ``replace=False`` with the chunked call
      below.
    """
    # NumPy counterpart of each dask generator, used to derive expected dtypes.
    np_generator = {
        da.random.RandomState: np.random.RandomState,
        da.random.default_rng: np.random.default_rng,
    }
    np_dtype = np_generator[generator_class]().choice(1, size=()).dtype
    size = (10, 3)
    chunks = 4
    x = generator_class().choice(3, size=size, chunks=chunks)
    assert x.dtype == np_dtype
    assert x.shape == size
    res = x.compute()
    assert res.dtype == np_dtype
    assert res.shape == size

    py_a = [1, 3, 5, 7, 9]
    np_a = np.array(py_a, dtype="f8")
    da_a = da.from_array(np_a, chunks=2)

    # Samples drawn from a 1-D population keep the population's dtype and
    # only contain values from it, whatever container holds the population.
    for a in [py_a, np_a, da_a]:
        x = generator_class().choice(a, size=size, chunks=chunks)
        res = x.compute()
        expected_dtype = np.asarray(a).dtype
        assert x.dtype == expected_dtype
        assert res.dtype == expected_dtype
        assert set(np.unique(res)).issubset(np_a)

    # p[0] == 0, so the first element of np_a must never be drawn.
    np_p = np.array([0, 0.2, 0.2, 0.3, 0.3])
    da_p = da.from_array(np_p, chunks=2)

    for a, p in [(da_a, np_p), (np_a, da_p)]:
        x = generator_class().choice(a, size=size, chunks=chunks, p=p)
        res = x.compute()
        assert x.dtype == np_a.dtype
        assert res.dtype == np_a.dtype
        assert set(np.unique(res)).issubset(np_a[1:])

    # An integer ``a`` combined with ``p`` must match NumPy's dtype for the
    # equivalent call.
    np_dtype = np_generator[generator_class]().choice(1, size=(), p=np.array([1])).dtype
    x = generator_class().choice(5, size=size, chunks=chunks, p=np_p)
    res = x.compute()
    assert x.dtype == np_dtype
    assert res.dtype == np_dtype

    # Each case violates exactly one constraint, so the ValueError comes from
    # the check named in its comment rather than from a different check that
    # happens to run first.
    errs = [
        (-1, None),  # negative a
        (np_a[:, None], None),  # a must be 1D
        (np_a, np_p[:, None]),  # p must be 1D (p still sums to 1)
        (np_a, np.array([0, 0.5, 0.5])),  # a and p must match (p sums to 1)
        (3, np_p),  # a and p must match (p sums to 1)
        (3, [0.2, 0.2, 0.3]),  # p must sum to 1 (sizes match)
    ]

    for a, p in errs:
        with pytest.raises(ValueError):
            generator_class().choice(a, size=size, chunks=chunks, p=p)

    # Sampling without replacement from this multi-chunk request is expected
    # to be unsupported.
    with pytest.raises(NotImplementedError):
        generator_class().choice(da_a, size=size, chunks=chunks, replace=False)
nothing calls this directly
no test coverage detected
searching dependent graphs…