github.com/dask/dask

Function _choice_validate_params

dask/array/random.py:828–895
(state, a, size, replace, p, axis, chunks)

Source from the content-addressed store, hash-verified

def _choice_validate_params(state, a, size, replace, p, axis, chunks):
    dependencies = []
    # Normalize and validate `a`
    if isinstance(a, Integral):
        if isinstance(state, Generator):
            if state._backend_name == "cupy":
                raise NotImplementedError(
                    "`choice` not supported for cupy-backed `Generator`."
                )
            meta = state._backend.random.default_rng().choice(1, size=(), p=None)
        elif isinstance(state, RandomState):
            # On windows the output dtype differs if p is provided or
            # # absent, see https://github.com/numpy/numpy/issues/9867
            dummy_p = state._backend.array([1]) if p is not None else p
            meta = state._backend.random.RandomState().choice(1, size=(), p=dummy_p)
        else:
            raise ValueError("Unknown generator class")
        len_a = a
        if a < 0:
            raise ValueError("a must be greater than 0")
    else:
        a = asarray(a)
        a = a.rechunk(a.shape)
        meta = a._meta
        if a.ndim != 1:
            raise ValueError("a must be one dimensional")
        len_a = len(a)
        dependencies.append(a)
        a = TaskRef(a.__dask_keys__()[0])

    # Normalize and validate `p`
    if p is not None:
        if not isinstance(p, Array):
            # If p is not a dask array, first check the sum is close
            # to 1 before converting.
            p = asarray_safe(p, like=p)
            if not np.isclose(p.sum(), 1, rtol=1e-7, atol=0):
                raise ValueError("probabilities do not sum to 1")
            p = asarray(p)
        else:
            p = p.rechunk(p.shape)

        if p.ndim != 1:
            raise ValueError("p must be one dimensional")
        if len(p) != len_a:
            raise ValueError("a and p must have the same size")

        dependencies.append(p)
        p = TaskRef(p.__dask_keys__()[0])

    if size is None:
        size = ()
    elif not isinstance(size, (tuple, list)):
        size = (size,)

    if axis != 0:
        raise ValueError("axis must be 0 since a is one dimensional")
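The checks on `a`, `size` and `axis` in the listing can be mirrored without dask or NumPy. The sketch below is a hypothetical, stdlib-only re-implementation: the name `normalize_choice_args` is invented for illustration, it is not part of dask's API, and it omits the backend `meta` construction, rechunking and task-graph bookkeeping. It also assumes that a nested list or tuple stands in for an array with more than one dimension.

```python
from numbers import Integral


def normalize_choice_args(a, size=None, axis=0):
    """Hypothetical, stdlib-only mirror of the `a`, `size` and `axis`
    checks in dask's _choice_validate_params (not a dask API)."""
    if isinstance(a, Integral):
        # An integer `a` means "sample from range(a)", so its length is `a`.
        if a < 0:
            raise ValueError("a must be greater than 0")
        len_a = a
    else:
        # Assumption: a nested list/tuple stands in for an array with ndim > 1.
        if any(isinstance(x, (list, tuple)) for x in a):
            raise ValueError("a must be one dimensional")
        len_a = len(a)

    # None -> (), a scalar n -> (n,); tuples and lists pass through unchanged.
    if size is None:
        size = ()
    elif not isinstance(size, (tuple, list)):
        size = (size,)

    # `a` is one-dimensional, so only axis 0 is meaningful.
    if axis != 0:
        raise ValueError("axis must be 0 since a is one dimensional")

    return len_a, size


# Usage
print(normalize_choice_args(5))                   # (5, ())
print(normalize_choice_args(["x", "y", "z"], 4))  # (3, (4,))
```

Note that an integer `a = 0` passes the `a < 0` check, even though the error message says `a` must be greater than 0.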
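For a `p` that is not a dask array, the listing checks the sum eagerly with `np.isclose(p.sum(), 1, rtol=1e-7, atol=0)`. NumPy defines that as `abs(x - 1) <= atol + rtol * abs(1)`, so with `atol=0` the sum must lie within 1e-7 of 1. The sketch below reproduces that comparison explicitly rather than using `math.isclose`, whose relative tolerance is scaled by the larger of the two magnitudes. The function name `check_probabilities` is hypothetical, it assumes a flat sequence of floats, and it sums with `math.fsum`, so results may differ from NumPy's summation in the last bits. A dask-array `p` is only rechunked to one chunk in the listing; its sum is not checked there.

```python
import math


def check_probabilities(p, len_a, rtol=1e-7):
    """Hypothetical, stdlib-only mirror of the eager checks dask applies
    to a non-dask `p` (assumes a flat sequence of floats)."""
    p = list(p)
    total = math.fsum(p)
    # np.isclose(total, 1, rtol=rtol, atol=0) means |total - 1| <= rtol * |1|.
    # Written as `not (... <= ...)` so a NaN total is rejected, as in NumPy.
    if not (abs(total - 1.0) <= rtol):
        raise ValueError("probabilities do not sum to 1")
    # The probability vector must have one entry per candidate in `a`.
    if len(p) != len_a:
        raise ValueError("a and p must have the same size")
    return p


# Usage
print(check_probabilities([0.2, 0.3, 0.5], 3))  # [0.2, 0.3, 0.5]
```

A vector such as `[0.5, 0.6]` is rejected before the length check runs, matching the order of the checks in the listing.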

Callers (2)

choice (method) · 0.70
choice (method) · 0.70

Calls (9)

asarray (function) · 0.90
TaskRef (class) · 0.90
asarray_safe (function) · 0.90
normalize_chunks (function) · 0.90
choice (method) · 0.45
RandomState (method) · 0.45
rechunk (method) · 0.45
__dask_keys__ (method) · 0.45
sum (method) · 0.45

Tested by

no test coverage detected
