Function _tree_reduce

dask/array/_array_expr/_reductions.py:202–259 (github.com/dask/dask)

Performs the tree-reduction step of a reduction. This is a low-level helper; users should call ``reduction`` or ``arg_reduction`` instead.
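
As a conceptual illustration only, the pure-Python sketch here shows what a tree reduction does to a 1-D list of per-block partial results. Intermediate rounds use a `combine` function and the last round uses `aggregate`, mirroring the roles those two arguments play in `_tree_reduce`. The names `tree_reduce`, `combine_sum_count`, and `aggregate_mean` are hypothetical and are not part of dask. The sketch also ignores chunk shapes, `keepdims`, concatenation, and graph construction.

```python
from typing import Callable, List, Sequence, Tuple, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def tree_reduce(
    partials: Sequence[T],
    combine: Callable[[List[T]], T],
    aggregate: Callable[[List[T]], R],
    split_every: int,
) -> R:
    """Hypothetical 1-D analogue of a tree reduction (not a dask API).

    Neighbouring groups of at most ``split_every`` partial results are merged
    with ``combine`` until at most ``split_every`` remain; ``aggregate`` then
    turns that last group into the final value.
    """
    if split_every < 2:
        raise ValueError("split_every must be >= 2")
    level = list(partials)
    while len(level) > split_every:
        level = [
            combine(level[i : i + split_every])
            for i in range(0, len(level), split_every)
        ]
    return aggregate(level)


def combine_sum_count(group: List[Tuple[int, int]]) -> Tuple[int, int]:
    # Intermediate rounds keep (sum, count) pairs so nothing is lost early.
    return (sum(s for s, _ in group), sum(c for _, c in group))


def aggregate_mean(group: List[Tuple[int, int]]) -> float:
    # The final round merges the last group and divides exactly once.
    total, count = combine_sum_count(group)
    return total / count


# Ten blocks of ten values each: (block sum, block count) covering 0..99.
partials = [(sum(range(10 * i, 10 * i + 10)), 10) for i in range(10)]
# With split_every=3: 10 partials -> 4 -> 2, then aggregate.
print(tree_reduce(partials, combine_sum_count, aggregate_mean, 3))  # 49.5
print(tree_reduce(list(range(100)), sum, sum, 4))  # 4950
```

The mean example shows why `combine` and `aggregate` can differ: dividing in an intermediate round would discard the counts needed later, so only the final step turns the running (sum, count) pair into a mean.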

(
    x,
    aggregate,
    axis,
    keepdims,
    dtype,
    split_every=None,
    combine=None,
    name=None,
    concatenate=True,
    reduced_meta=None,
)
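
The `split_every` argument may be an int or a dict keyed by axis. The hypothetical helper `normalize_split_every` is a standalone copy of the normalization at the top of `_tree_reduce`, which makes the per-axis result easy to inspect. Its `default=16` parameter stands in for `config.get("split_every", 16)` and assumes no user configuration is set.

```python
from numbers import Integral
from typing import Dict, Sequence, Union


def normalize_split_every(
    split_every: Union[int, Dict[int, int], None],
    axis: Sequence[int],
    default: int = 16,
) -> Dict[int, int]:
    """Hypothetical standalone copy of the split_every normalization.

    ``default`` stands in for ``config.get("split_every", 16)`` with no
    user configuration set (an assumption for this sketch).
    """
    # A falsy value (None or 0) falls back to the default.
    split_every = split_every or default
    if isinstance(split_every, dict):
        # Per-axis dict: reduced axes missing from it default to 2.
        return {k: split_every.get(k, 2) for k in axis}
    if isinstance(split_every, Integral):
        # An int is a total fan-in shared across the reduced axes: take the
        # len(axis)-th root, truncate it, and never go below 2. (The source
        # calls builtins.max; plain max is used here.)
        n = max(int(split_every ** (1 / (len(axis) or 1))), 2)
        return dict.fromkeys(axis, n)
    raise ValueError("split_every must be an int or a dict")


print(normalize_split_every(None, (0, 1)))    # {0: 4, 1: 4}
print(normalize_split_every({0: 8}, (0, 1)))  # {0: 8, 1: 2}
print(normalize_split_every(3, (0, 1)))       # {0: 2, 1: 2}
```

With the default of 16 and two reduced axes, each axis merges 4 blocks per step, so one task sees about 4 × 4 = 16 input blocks. An int that is too small for the number of axes is clamped to 2 per axis.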


```python
def _tree_reduce(
    x,
    aggregate,
    axis,
    keepdims,
    dtype,
    split_every=None,
    combine=None,
    name=None,
    concatenate=True,
    reduced_meta=None,
):
    """Perform the tree reduction step of a reduction.

    Lower level, users should use ``reduction`` or ``arg_reduction`` directly.
    """
    # Normalize split_every
    split_every = split_every or config.get("split_every", 16)
    if isinstance(split_every, dict):
        split_every = {k: split_every.get(k, 2) for k in axis}
    elif isinstance(split_every, Integral):
        n = builtins.max(int(split_every ** (1 / (len(axis) or 1))), 2)
        split_every = dict.fromkeys(axis, n)
    else:
        raise ValueError("split_every must be a int or a dict")

    # Reduce across intermediates
    depth = 1
    for i, n in enumerate(x.numblocks):
        if i in split_every and split_every[i] != 1:
            depth = int(builtins.max(depth, math.ceil(math.log(n, split_every[i]))))
    func = partial(combine or aggregate, axis=axis, keepdims=True)
    if concatenate:
        func = compose(func, partial(_concatenate2, axes=sorted(axis)))
    for _ in range(depth - 1):
        x = PartialReduce(
            x,
            func,
            split_every,
            True,
            dtype=dtype,
            name=(name or funcname(combine or aggregate)) + "-partial",
            reduced_meta=reduced_meta,
        )
    func = partial(aggregate, axis=axis, keepdims=keepdims)
    if concatenate:
        func = compose(func, partial(_concatenate2, axes=sorted(axis)))
    return new_collection(
        PartialReduce(
            x,
            func,
            split_every,
            keepdims=keepdims,
            dtype=dtype,
            name=(name or funcname(aggregate)) + "-aggregate",
            reduced_meta=reduced_meta,
        )
    )
```
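
The number of `PartialReduce` layers is driven by `depth`: the reduced axis with the most levels needs ceil(log_k n) of them, where n is its block count and k its `split_every` factor. All but the last layer apply `combine` (or `aggregate` when `combine` is None) with `keepdims=True`. The hypothetical helpers `reduction_depth` and `blocks_after_rounds` sketch this. The first replicates the depth loop. The second tracks how many blocks remain per axis, under the assumption that each layer merges up to `split_every[i]` neighbouring blocks along every reduced axis `i` and leaves other axes untouched.

```python
import math
from typing import Dict, Sequence, Tuple


def reduction_depth(numblocks: Sequence[int], split_every: Dict[int, int]) -> int:
    """Hypothetical mirror of the depth loop in _tree_reduce."""
    depth = 1
    for i, n in enumerate(numblocks):
        # Only reduced axes appear in split_every; a factor of 1 is skipped
        # because math.log(n, 1) would divide by zero.
        if i in split_every and split_every[i] != 1:
            depth = int(max(depth, math.ceil(math.log(n, split_every[i]))))
    return depth


def blocks_after_rounds(
    numblocks: Sequence[int], split_every: Dict[int, int], rounds: int
) -> Tuple[int, ...]:
    """Blocks left per axis after ``rounds`` reduction layers.

    Assumes (hypothetically) that each layer merges up to split_every[i]
    neighbouring blocks along every reduced axis i and leaves other axes alone.
    """
    counts = list(numblocks)
    for _ in range(rounds):
        counts = [
            math.ceil(n / split_every[i]) if i in split_every else n
            for i, n in enumerate(counts)
        ]
    return tuple(counts)


numblocks = (100, 10)
split = {0: 4, 1: 4}  # e.g. split_every=16 spread over axis=(0, 1)
depth = reduction_depth(numblocks, split)
print(depth)                                             # 4
print(blocks_after_rounds(numblocks, split, depth - 1))  # (2, 1)
print(blocks_after_rounds(numblocks, split, depth))      # (1, 1)
```

After the `depth - 1` combine layers, at most `split_every[i]` blocks remain along each reduced axis, so the final `aggregate` layer can collapse each of them to one block. Because `math.log` works in floating point, an exact power such as n = k**3 may round up to one extra level. That would only add an intermediate layer and should not change the result.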

Callers (1)

reduction · Function · 0.70

Calls (5)

funcname · Function · 0.90
new_collection · Function · 0.90
PartialReduce · Class · 0.85
get · Method · 0.45
max · Method · 0.45

Tested by

no test coverage detected
