(state, a, size, replace, p, axis, chunks)
| 826 | |
| 827 | |
| 828 | def _choice_validate_params(state, a, size, replace, p, axis, chunks): |
| 829 | dependencies = [] |
| 830 | # Normalize and validate `a` |
| 831 | if isinstance(a, Integral): |
| 832 | if isinstance(state, Generator): |
| 833 | if state._backend_name == "cupy": |
| 834 | raise NotImplementedError( |
| 835 | "`choice` not supported for cupy-backed `Generator`." |
| 836 | ) |
| 837 | meta = state._backend.random.default_rng().choice(1, size=(), p=None) |
| 838 | elif isinstance(state, RandomState): |
| 839 | # On Windows the output dtype differs if p is provided or |
| 840 | # absent, see https://github.com/numpy/numpy/issues/9867 |
| 841 | dummy_p = state._backend.array([1]) if p is not None else p |
| 842 | meta = state._backend.random.RandomState().choice(1, size=(), p=dummy_p) |
| 843 | else: |
| 844 | raise ValueError("Unknown generator class") |
| 845 | len_a = a |
| 846 | if a < 0: |
| 847 | raise ValueError("a must be greater than 0") |
| 848 | else: |
| 849 | a = asarray(a) |
| 850 | a = a.rechunk(a.shape) |
| 851 | meta = a._meta |
| 852 | if a.ndim != 1: |
| 853 | raise ValueError("a must be one dimensional") |
| 854 | len_a = len(a) |
| 855 | dependencies.append(a) |
| 856 | a = TaskRef(a.__dask_keys__()[0]) |
| 857 | |
| 858 | # Normalize and validate `p` |
| 859 | if p is not None: |
| 860 | if not isinstance(p, Array): |
| 861 | # If p is not a dask array, first check the sum is close |
| 862 | # to 1 before converting. |
| 863 | p = asarray_safe(p, like=p) |
| 864 | if not np.isclose(p.sum(), 1, rtol=1e-7, atol=0): |
| 865 | raise ValueError("probabilities do not sum to 1") |
| 866 | p = asarray(p) |
| 867 | else: |
| 868 | p = p.rechunk(p.shape) |
| 869 | |
| 870 | if p.ndim != 1: |
| 871 | raise ValueError("p must be one dimensional") |
| 872 | if len(p) != len_a: |
| 873 | raise ValueError("a and p must have the same size") |
| 874 | |
| 875 | dependencies.append(p) |
| 876 | p = TaskRef(p.__dask_keys__()[0]) |
| 877 | |
| 878 | if size is None: |
| 879 | size = () |
| 880 | elif not isinstance(size, (tuple, list)): |
| 881 | size = (size,) |
| 882 | |
| 883 | if axis != 0: |
| 884 | raise ValueError("axis must be 0 since a is one dimensional") |
| 885 |
no test coverage detected
searching dependent graphs…