(state, a, size, replace, p, axis, chunks)
| 803 | |
| 804 | |
| 805 | def _choice_validate_params(state, a, size, replace, p, axis, chunks): |
| 806 | dependencies = [] |
| 807 | # Normalize and validate `a` |
| 808 | if isinstance(a, Integral): |
| 809 | if isinstance(state, Generator): |
| 810 | if state._backend_name == "cupy": |
| 811 | raise NotImplementedError( |
| 812 | "`choice` not supported for cupy-backed `Generator`." |
| 813 | ) |
| 814 | meta = state._backend.random.default_rng().choice(1, size=(), p=None) |
| 815 | elif isinstance(state, RandomState): |
| 816 | # On windows the output dtype differs if p is provided or |
| 817 | # # absent, see https://github.com/numpy/numpy/issues/9867 |
| 818 | dummy_p = state._backend.array([1]) if p is not None else p |
| 819 | meta = state._backend.random.RandomState().choice(1, size=(), p=dummy_p) |
| 820 | else: |
| 821 | raise ValueError("Unknown generator class") |
| 822 | len_a = a |
| 823 | if a < 0: |
| 824 | raise ValueError("a must be greater than 0") |
| 825 | else: |
| 826 | a = asarray(a) |
| 827 | a = a.rechunk(a.shape) |
| 828 | meta = a._meta |
| 829 | if a.ndim != 1: |
| 830 | raise ValueError("a must be one dimensional") |
| 831 | len_a = len(a) |
| 832 | dependencies.append(a) |
| 833 | a = a.__dask_keys__()[0] |
| 834 | |
| 835 | # Normalize and validate `p` |
| 836 | if p is not None: |
| 837 | if not isinstance(p, Array): |
| 838 | # If p is not a dask array, first check the sum is close |
| 839 | # to 1 before converting. |
| 840 | p = asarray_safe(p, like=p) |
| 841 | if not np.isclose(p.sum(), 1, rtol=1e-7, atol=0): |
| 842 | raise ValueError("probabilities do not sum to 1") |
| 843 | p = asarray(p) |
| 844 | else: |
| 845 | p = p.rechunk(p.shape) |
| 846 | |
| 847 | if p.ndim != 1: |
| 848 | raise ValueError("p must be one dimensional") |
| 849 | if len(p) != len_a: |
| 850 | raise ValueError("a and p must have the same size") |
| 851 | |
| 852 | dependencies.append(p) |
| 853 | p = p.__dask_keys__()[0] |
| 854 | |
| 855 | if size is None: |
| 856 | size = () |
| 857 | |
| 858 | if axis != 0: |
| 859 | raise ValueError("axis must be 0 since a is one dimensional") |
| 860 | |
| 861 | chunks = normalize_chunks(chunks, size, dtype=np.float64) |
| 862 | if not replace and len(chunks[0]) > 1: |
no test coverage detected
searching dependent graphs…