MCPcopy Index your code
hub / github.com/dask/dask / _choice_validate_params

Function _choice_validate_params

dask/array/_array_expr/random.py:805–870  ·  view source on GitHub ↗
(state, a, size, replace, p, axis, chunks)

Source from the content-addressed store, hash-verified

803
804
805def _choice_validate_params(state, a, size, replace, p, axis, chunks):
806 dependencies = []
807 # Normalize and validate `a`
808 if isinstance(a, Integral):
809 if isinstance(state, Generator):
810 if state._backend_name == "cupy":
811 raise NotImplementedError(
812 "`choice` not supported for cupy-backed `Generator`."
813 )
814 meta = state._backend.random.default_rng().choice(1, size=(), p=None)
815 elif isinstance(state, RandomState):
816 # On windows the output dtype differs if p is provided or
817 # # absent, see https://github.com/numpy/numpy/issues/9867
818 dummy_p = state._backend.array([1]) if p is not None else p
819 meta = state._backend.random.RandomState().choice(1, size=(), p=dummy_p)
820 else:
821 raise ValueError("Unknown generator class")
822 len_a = a
823 if a < 0:
824 raise ValueError("a must be greater than 0")
825 else:
826 a = asarray(a)
827 a = a.rechunk(a.shape)
828 meta = a._meta
829 if a.ndim != 1:
830 raise ValueError("a must be one dimensional")
831 len_a = len(a)
832 dependencies.append(a)
833 a = a.__dask_keys__()[0]
834
835 # Normalize and validate `p`
836 if p is not None:
837 if not isinstance(p, Array):
838 # If p is not a dask array, first check the sum is close
839 # to 1 before converting.
840 p = asarray_safe(p, like=p)
841 if not np.isclose(p.sum(), 1, rtol=1e-7, atol=0):
842 raise ValueError("probabilities do not sum to 1")
843 p = asarray(p)
844 else:
845 p = p.rechunk(p.shape)
846
847 if p.ndim != 1:
848 raise ValueError("p must be one dimensional")
849 if len(p) != len_a:
850 raise ValueError("a and p must have the same size")
851
852 dependencies.append(p)
853 p = p.__dask_keys__()[0]
854
855 if size is None:
856 size = ()
857
858 if axis != 0:
859 raise ValueError("axis must be 0 since a is one dimensional")
860
861 chunks = normalize_chunks(chunks, size, dtype=np.float64)
862 if not replace and len(chunks[0]) > 1:

Callers 2

choice (Method) · 0.70
choice (Method) · 0.70

Calls 8

asarray (Function) · 0.90
asarray_safe (Function) · 0.90
normalize_chunks (Function) · 0.90
choice (Method) · 0.45
RandomState (Method) · 0.45
rechunk (Method) · 0.45
__dask_keys__ (Method) · 0.45
sum (Method) · 0.45

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…