
Function _tree_reduce

dask/array/_reductions_generic.py, lines 215–270 (github.com/dask/dask)

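Performs the tree-reduction step of a reduction. Given x, an array of per-block intermediate results, it combines groups of neighbouring blocks (at most split_every per reduced axis) over several rounds and finishes with a single aggregate step. This is a lower-level helper: users should call ``reduction`` or ``arg_reduction`` directly.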
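Conceptually, a tree reduction replaces one large final merge with several rounds of small merges. The following is a minimal pure-Python sketch of that idea, not dask's implementation: it works on a flat list of partial results, and the helper tree_reduce and its split_every fan-in argument are illustrative assumptions.

```python
import operator
from functools import reduce


def tree_reduce(blocks, combine, split_every=4):
    """Collapse a list of partial results in rounds (illustrative helper).

    Each round merges groups of at most ``split_every`` neighbouring
    values with ``combine`` until a single value remains. Returns the
    final value and the number of rounds used.
    """
    if split_every < 2:
        raise ValueError("split_every must be at least 2")
    level = list(blocks)
    if not level:
        raise ValueError("need at least one block")
    rounds = 0
    while len(level) > 1:
        # One tree level: combine each group of up to split_every neighbours.
        level = [
            reduce(combine, level[i : i + split_every])
            for i in range(0, len(level), split_every)
        ]
        rounds += 1
    return level[0], rounds


# 16 partial sums with fan-in 4: 16 -> 4 -> 1, i.e. two rounds.
total, rounds = tree_reduce(range(16), operator.add, split_every=4)
print(total, rounds)  # 120 2
```

For n > 1 inputs the loop runs ceil(log_k n) rounds with k = split_every, which corresponds to the depth _tree_reduce computes (the real function always builds at least one level, the final aggregate).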

(
    x,
    aggregate,
    axis,
    keepdims,
    dtype,
    split_every=None,
    combine=None,
    name=None,
    concatenate=True,
    reduced_meta=None,
)
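The split_every argument in this signature is normalized at the top of the function body, before any graph is built. The sketch below restates that logic as a standalone function so its behaviour can be checked in isolation; the name normalize_split_every and its default parameter (standing in for dask's config.get("split_every", 4)) are assumptions for illustration.

```python
from numbers import Integral


def normalize_split_every(split_every, axis, default=4):
    """Return an {axis: fan_in} dict, following _tree_reduce's normalization.

    ``default`` is a hypothetical stand-in for config.get("split_every", 4).
    """
    split_every = split_every or default
    if isinstance(split_every, dict):
        # Per-axis fan-in; reduced axes missing from the dict default to 2.
        return {k: split_every.get(k, 2) for k in axis}
    if isinstance(split_every, Integral):
        # Spread a total fan-in over the reduced axes (k-th root), floored at 2.
        n = max(int(split_every ** (1 / (len(axis) or 1))), 2)
        return dict.fromkeys(axis, n)
    raise ValueError("split_every must be an int or a dict")


print(normalize_split_every(None, (0,)))      # {0: 4}
print(normalize_split_every(16, (0, 1)))      # {0: 4, 1: 4}
print(normalize_split_every({0: 3}, (0, 1)))  # {0: 3, 1: 2}
```

Note the integer case: a fan-in of 16 over two axes becomes 4 per axis, so each merge still covers about 16 blocks in total, while 8 over two axes rounds down to 2 per axis because int(8 ** 0.5) == 2.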

Source

```python
# Module-level names used below (config, Integral, builtins, math, partial,
# compose, _concatenate2, partial_reduce, funcname) come from the enclosing
# module's imports and definitions, which are not shown in this excerpt.
def _tree_reduce(
    x,
    aggregate,
    axis,
    keepdims,
    dtype,
    split_every=None,
    combine=None,
    name=None,
    concatenate=True,
    reduced_meta=None,
):
    """Perform the tree reduction step of a reduction.

    Lower level, users should use ``reduction`` or ``arg_reduction`` directly.
    """
    # Normalize split_every
    # (an int is spread across the reduced axes; a dict gives per-axis
    # fan-in, with missing axes defaulting to 2)
    split_every = split_every or config.get("split_every", 4)
    if isinstance(split_every, dict):
        split_every = {k: split_every.get(k, 2) for k in axis}
    elif isinstance(split_every, Integral):
        n = builtins.max(int(split_every ** (1 / (len(axis) or 1))), 2)
        split_every = dict.fromkeys(axis, n)
    else:
        raise ValueError("split_every must be a int or a dict")

    # Reduce across intermediates
    # depth = number of tree levels needed to collapse every reduced axis
    depth = 1
    for i, n in enumerate(x.numblocks):
        if i in split_every and split_every[i] != 1:
            depth = int(builtins.max(depth, math.ceil(math.log(n, split_every[i]))))
    # Intermediate levels keep the reduced axes (keepdims=True) so their
    # outputs can be concatenated again at the next level.
    func = partial(combine or aggregate, axis=axis, keepdims=True)
    if concatenate:
        func = compose(func, partial(_concatenate2, axes=sorted(axis)))
    for _ in range(depth - 1):
        x = partial_reduce(
            func,
            x,
            split_every,
            True,
            dtype=dtype,
            name=(name or funcname(combine or aggregate)) + "-partial",
            reduced_meta=reduced_meta,
        )
    # Final level: apply aggregate with the caller's keepdims.
    func = partial(aggregate, axis=axis, keepdims=keepdims)
    if concatenate:
        func = compose(func, partial(_concatenate2, axes=sorted(axis)))
    return partial_reduce(
        func,
        x,
        split_every,
        keepdims=keepdims,
        dtype=dtype,
        name=(name or funcname(aggregate)) + "-aggregate",
        reduced_meta=reduced_meta,
    )
```
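The depth loop and the repeated partial_reduce calls determine the shape of the tree. Below is a hedged sketch that reuses the depth formula from the source and simulates how many blocks remain per axis after each level. It assumes partial_reduce merges split_every.get(i, 1) adjacent blocks along axis i (so axes not being reduced are untouched) and ignores keepdims, chunk sizes and the task graph; tree_depth and blocks_per_round are hypothetical helpers.

```python
import math


def tree_depth(numblocks, split_every):
    """Number of tree levels, using the same formula as _tree_reduce."""
    depth = 1
    for i, n in enumerate(numblocks):
        # Axes with fan-in 1 are skipped (log base 1 is undefined).
        if i in split_every and split_every[i] != 1:
            depth = int(max(depth, math.ceil(math.log(n, split_every[i]))))
    return depth


def blocks_per_round(numblocks, split_every):
    """Block-grid shape before the first level and after each level.

    Assumes every level merges split_every.get(i, 1) adjacent blocks along
    axis i, and ignores keepdims, chunk sizes and the task graph.
    """
    shapes = [tuple(numblocks)]
    for _ in range(tree_depth(numblocks, split_every)):
        shapes.append(
            tuple(
                math.ceil(n / split_every.get(i, 1))
                for i, n in enumerate(shapes[-1])
            )
        )
    return shapes


# Reduce a 10 x 3 block grid over axis 0 with fan-in 4:
# one combine round (10 -> 3) plus the final aggregate (3 -> 1).
print(blocks_per_round((10, 3), {0: 4}))  # [(10, 3), (3, 3), (1, 3)]
```

Because depth is the largest ceil(log_k n) over the reduced axes, after depth - 1 combine rounds every reduced axis has at most k blocks left, so the single aggregate call can finish the reduction.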
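With concatenate=True, each level first stitches its group of input blocks together with _concatenate2 and only then applies combine or aggregate; compose(func, g) evaluates as func(g(...)). A minimal NumPy sketch of that pipeline follows. Both helpers are assumptions: compose is a two-function stand-in for the right-to-left compose used in the source, and concatenate_nested is a hypothetical re-implementation of what _concatenate2 does (join a nested list of arrays, one axis per nesting level), not dask's internal function.

```python
from functools import partial

import numpy as np


def compose(f, g):
    """Two-function stand-in for a right-to-left compose: f(g(...))."""
    return lambda *args, **kwargs: f(g(*args, **kwargs))


def concatenate_nested(blocks, axes):
    """Hypothetical stand-in for _concatenate2: join a nested list of arrays,
    the outermost list level along axes[0], the next along axes[1], ..."""
    if not axes:
        return blocks  # innermost level: a plain array
    return np.concatenate(
        [concatenate_nested(b, axes[1:]) for b in blocks], axis=axes[0]
    )


axis = (0, 1)
# Intermediate level: keepdims=True keeps both reduced axes as length 1.
combine_step = compose(
    partial(np.sum, axis=axis, keepdims=True),
    partial(concatenate_nested, axes=sorted(axis)),
)

# A 2 x 2 group of blocks with uneven chunk sizes, all ones.
group = [
    [np.ones((2, 2)), np.ones((2, 3))],
    [np.ones((1, 2)), np.ones((1, 3))],
]
print(combine_step(group))  # [[15.]]
```

Keeping length-1 reduced axes on intermediate levels is what lets their outputs be concatenated again at the next level; only the final aggregate applies the caller's keepdims.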

Callers (3)

arg_reduction (function) · 0.90
bincount (function) · 0.90
reduction (function) · 0.70

Calls (4)

funcname (function) · 0.90
partial_reduce (function) · 0.85
get (method) · 0.45
max (method) · 0.45

Tested by

no test coverage detected
