github.com/dask/dask

Function arg_reduction

dask/array/reductions.py:844–918

Generic function for argreduction.
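The docstring below names three callables, `chunk`, `combine` and `agg`. Conceptually they form a map-then-reduce pipeline: `chunk` finds a candidate (value, global index) pair in each block, `combine` merges candidates, and `agg` keeps only the index. The value has to travel with the index, because comparing bare block-local indices across blocks is meaningless. This is a minimal NumPy sketch for a 1-D `argmin`, assuming the blocks are given as a list of sizes. The helpers `chunk_argmin`, `combine_argmin`, `agg_argmin` and `chunked_argmin` are hypothetical and are not dask's `arg_chunk`, `arg_combine` or `arg_agg`.

```python
import numpy as np


def chunk_argmin(block, offset):
    """Local phase: (minimum value, global index) for one block."""
    local = int(np.argmin(block))
    # Shift the block-local index by the block's starting offset.
    return block[local], local + offset


def combine_argmin(pairs):
    """Merge phase: keep the pair with the smallest value.

    min() returns the first minimal element, so ties resolve to the
    earliest block, matching np.argmin's first-occurrence rule.
    """
    return min(pairs, key=lambda p: p[0])


def agg_argmin(pair):
    """Final phase: drop the value and keep only the index."""
    return pair[1]


def chunked_argmin(x, chunk_sizes):
    """Hypothetical driver: split x into consecutive blocks and reduce."""
    assert sum(chunk_sizes) == len(x)
    pairs, start = [], 0
    for size in chunk_sizes:
        pairs.append(chunk_argmin(x[start:start + size], start))
        start += size
    return agg_argmin(combine_argmin(pairs))


x = np.array([5, 3, 8, 1, 9, 1, 4])
print(chunked_argmin(x, [3, 2, 2]))  # 3, same as np.argmin(x)
```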

(
    x, chunk, combine, agg, axis=None, keepdims=False, split_every=None, out=None
)
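Before any graph is built, the body normalizes `axis` and checks chunk sizes. `None` means reduce over the flattened array (`ravel = True`), a single integer becomes a one-element tuple (with `ravel` true only for 1-D input), and anything else raises `TypeError`. Any reduced axis with more than one block and a NaN (unknown) block size is rejected, since later block offsets cannot be computed. A standalone sketch of that validation, assuming `chunks` is a dask-style tuple of per-dimension block sizes; `normalize_arg_axis` is a hypothetical helper, and its bounds check only approximates dask's `validate_axis`.

```python
import math
from numbers import Integral


def normalize_arg_axis(axis, chunks):
    """Return (axis_tuple, ravel) following arg_reduction's preamble.

    Hypothetical helper: `chunks` is a tuple of per-dimension block sizes,
    e.g. ((2, 2), (3,)) for a 4x3 array split into two row blocks.
    """
    ndim = len(chunks)
    if axis is None:
        axis = tuple(range(ndim))
        ravel = True
    elif isinstance(axis, Integral):
        # Approximates validate_axis: accept -ndim <= axis < ndim.
        if not -ndim <= axis < ndim:
            raise ValueError(f"axis {axis} is out of bounds for {ndim}-D array")
        axis = (axis % ndim,)
        ravel = ndim == 1
    else:
        raise TypeError(f"axis must be either `None` or int, got '{axis}'")

    for ax in axis:
        sizes = chunks[ax]
        # With more than one block, a NaN size leaves the global offsets
        # of the following blocks unknown.
        if len(sizes) > 1 and any(math.isnan(s) for s in sizes):
            raise ValueError("unknown chunk sizes along axis %d" % ax)
    return axis, ravel


print(normalize_arg_axis(None, ((2, 2), (3,))))  # ((0, 1), True)
print(normalize_arg_axis(-1, ((2, 2), (3,))))    # ((1,), False)
```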

def arg_reduction(
    x, chunk, combine, agg, axis=None, keepdims=False, split_every=None, out=None
):
    """Generic function for argreduction.

    Parameters
    ----------
    x : Array
    chunk : callable
        Partialed ``arg_chunk``.
    combine : callable
        Partialed ``arg_combine``.
    agg : callable
        Partialed ``arg_agg``.
    axis : int, optional
    split_every : int or dict, optional
    """
    if axis is None:
        axis = tuple(range(x.ndim))
        ravel = True
    elif isinstance(axis, Integral):
        axis = validate_axis(axis, x.ndim)
        axis = (axis,)
        ravel = x.ndim == 1
    else:
        raise TypeError(f"axis must be either `None` or int, got '{axis}'")

    for ax in axis:
        chunks = x.chunks[ax]
        if len(chunks) > 1 and np.isnan(chunks).any():
            raise ValueError(
                "Arg-reductions do not work with arrays that have "
                "unknown chunksizes. At some point in your computation "
                "this array lost chunking information.\n\n"
                "A possible solution is with \n"
                " x.compute_chunk_sizes()"
            )

    # Map chunk across all blocks
    name = f"arg-reduce-{tokenize(axis, x, chunk, combine, split_every)}"
    old = x.name
    keys = list(product(*map(range, x.numblocks)))
    offsets = list(product(*(accumulate(operator.add, bd[:-1], 0) for bd in x.chunks)))
    if ravel:
        offset_info = zip(offsets, repeat(x.shape))
    else:
        offset_info = pluck(axis[0], offsets)

    chunks = tuple((1,) * len(c) if i in axis else c for (i, c) in enumerate(x.chunks))
    dsk = {
        (name,) + k: (chunk, (old,) + k, axis, off)
        for (k, off) in zip(keys, offset_info)
    }

    dtype = np.argmin(asarray_safe([1], like=meta_from_array(x)))
    meta = None
    if is_arraylike(dtype):
        # This case occurs on non-NumPy types (e.g., CuPy), where the returned
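The listing above precomputes each block's starting coordinates: `accumulate(operator.add, bd[:-1], 0)` yields the running start index of every block along one dimension, and `product` pairs those starts into one offset tuple per block key. When `ravel` is true each task also receives `x.shape`, so the chunk function can map a block-local position to a position in the flattened full array; otherwise only the offset along the reduced axis is passed (`pluck(axis[0], offsets)`). The sketch below reproduces that bookkeeping with NumPy and the standard library. `flat_argmin_of_block` is hypothetical, and the unravel/re-ravel approach is an assumption about what a chunk function does with `(offset, shape)`, not a copy of dask's `arg_chunk`.

```python
import operator
from itertools import accumulate, product

import numpy as np


def block_offsets(chunks):
    """Start coordinates of every block, in the same order as the block keys.

    Same result as the listing's
    product(*(accumulate(operator.add, bd[:-1], 0) for bd in x.chunks)),
    but using itertools.accumulate(..., initial=0) instead of toolz.
    """
    starts = [accumulate(bd[:-1], operator.add, initial=0) for bd in chunks]
    return list(product(*starts))


def flat_argmin_of_block(block, offset, shape):
    """Hypothetical chunk step for ravel=True: return the block minimum and
    its position in the flattened full array of shape `shape`."""
    local = np.unravel_index(np.argmin(block), block.shape)
    global_pos = tuple(o + i for o, i in zip(offset, local))
    return block.min(), int(np.ravel_multi_index(global_pos, shape))


x = np.arange(12).reshape(4, 3)[::-1]  # the global minimum, 0, is in the last row
chunks = ((2, 2), (2, 1))  # dask-style block sizes per dimension
keys = list(product(*(range(len(bd)) for bd in chunks)))
offsets = block_offsets(chunks)

results = []
for key, off in zip(keys, offsets):
    # Slice out block `key` using its start offset and its size.
    index = tuple(slice(o, o + bd[k]) for o, bd, k in zip(off, chunks, key))
    results.append(flat_argmin_of_block(x[index], off, x.shape))

best = min(results, key=lambda r: r[0])[1]
print(offsets)  # [(0, 0), (0, 2), (2, 0), (2, 2)]
print(best)     # 9, same as np.argmin(x)
```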

Callers (4)

argmax · function · 0.85
argmin · function · 0.85
nanargmax · function · 0.85
nanargmin · function · 0.85
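All four callers reuse this one function and differ only in the reduction pair they bind. The docstring's "Partialed ``arg_chunk``" wording suggests they pass `functools.partial` objects with the value and index functions already filled in. A toy 1-D sketch of that pattern; `pair_chunk`, `pair_combine`, `pair_agg`, `make_parts` and `reduce_1d` are hypothetical stand-ins, not dask's `arg_chunk`, `arg_combine` or `arg_agg`, and the driver does a single combine step instead of a tree.

```python
from functools import partial

import numpy as np


def pair_chunk(func, argfunc, block, offset):
    """Hypothetical chunk step: (value, global index) for one 1-D block."""
    return func(block), int(argfunc(block)) + offset


def pair_combine(argfunc, pairs):
    """Hypothetical combine step: let argfunc pick the winning pair."""
    values = np.array([v for v, _ in pairs])
    return pairs[int(argfunc(values))]


def pair_agg(pair):
    """Hypothetical agg step: keep only the index."""
    return pair[1]


def make_parts(func, argfunc):
    """Bind the value/index functions once, as the callers' partials would."""
    return partial(pair_chunk, func, argfunc), partial(pair_combine, argfunc), pair_agg


def reduce_1d(chunk, combine, agg, blocks):
    """Toy driver standing in for arg_reduction (1-D, one combine level)."""
    pairs, offset = [], 0
    for b in blocks:
        pairs.append(chunk(b, offset))
        offset += len(b)
    return agg(combine(pairs))


# np.nanargmin/np.nanargmax raise ValueError on an all-NaN block;
# this toy does not handle that case.
PARTS = {
    "argmin": make_parts(np.min, np.argmin),
    "argmax": make_parts(np.max, np.argmax),
    "nanargmin": make_parts(np.nanmin, np.nanargmin),
    "nanargmax": make_parts(np.nanmax, np.nanargmax),
}

blocks = [np.array([4.0, np.nan]), np.array([1.0, 7.0]), np.array([9.0])]
print(reduce_1d(*PARTS["nanargmin"], blocks))  # 2, same as np.nanargmin
```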

Calls (12)

validate_axis · function · 0.90
asarray_safe · function · 0.90
meta_from_array · function · 0.90
is_arraylike · function · 0.90
Array · class · 0.90
_tree_reduce · function · 0.90
handle_out · function · 0.90
argmin · method · 0.80
from_collections · method · 0.80
repeat · function · 0.70
tokenize · function · 0.50
any · method · 0.45
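Judging by the call list, the per-block results are handed to `_tree_reduce`, which merges them in stages; `split_every` roughly bounds how many inputs each combine task sees, so it trades the number of tasks against tree depth. A minimal pure-Python sketch of that staged merge for argmin-style `(value, index)` pairs. `tree_reduce`, `combine_min_pairs` and the returned depth are illustrative assumptions, not the signature of dask's `_tree_reduce`, and the per-axis `dict` form of `split_every` accepted by the real function is not modeled.

```python
def tree_reduce(parts, combine, split_every=4):
    """Merge partial results level by level, at most `split_every` at a time.

    Returns the final result and the number of combine levels used.
    """
    if split_every < 2:
        raise ValueError("split_every must be at least 2")
    level = list(parts)
    depth = 0
    while len(level) > 1:
        # Consecutive groups keep block order, so ties still resolve to the
        # earliest index when `combine` prefers the first minimal pair.
        level = [
            combine(level[i : i + split_every])
            for i in range(0, len(level), split_every)
        ]
        depth += 1
    return level[0], depth


def combine_min_pairs(group):
    """argmin-style combine on (value, global_index) pairs."""
    return min(group, key=lambda pair: pair[0])


values = [7, 3, 9, 3, 5, 1, 8, 1, 6, 2]
pairs = [(v, i) for i, v in enumerate(values)]  # one pair per "block"
print(tree_reduce(pairs, combine_min_pairs, split_every=3))  # ((1, 5), 3)
print(tree_reduce(pairs, combine_min_pairs, split_every=4))  # ((1, 5), 2)
```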

Tested by

no test coverage detected
