Perform the tree-reduction step of a reduction. This is a lower-level helper; users should call ``reduction`` or ``arg_reduction`` instead.
(
x,
aggregate,
axis,
keepdims,
dtype,
split_every=None,
combine=None,
name=None,
concatenate=True,
reduced_meta=None,
)
| 200 | |
| 201 | |
def _floor_int_root(value, k):
    """Return the largest integer ``r`` such that ``r ** k <= value``.

    ``int(value ** (1 / k))`` is not reliable: ``64 ** (1 / 3)`` evaluates to
    ``3.9999999999999996`` and truncates to 3 instead of 4, and a negative
    ``value`` with ``k > 1`` produces a complex number.  This helper uses exact
    integer arithmetic instead.

    Values below 1 are returned unchanged (as ints); the caller clamps the
    result to a minimum of 2.
    """
    value = int(value)
    if value < 1:
        return value
    # Invariant: lo ** k <= value < (hi + 1) ** k.  The initial ``hi`` satisfies
    # hi ** k > value because value < 2 ** value.bit_length().
    lo, hi = 1, 1 << (value.bit_length() // k + 1)
    while lo < hi:
        mid = (lo + hi + 1) // 2
        if mid ** k <= value:
            lo = mid
        else:
            hi = mid - 1
    return lo


def _tree_levels(nblocks, factor):
    """Return ``ceil(log(nblocks, factor))``, computed without float error.

    After ``d`` rounds of grouping ``factor`` blocks together,
    ``ceil(nblocks / factor ** d)`` blocks remain, so the number of rounds
    needed to reach a single block is the smallest ``d`` with
    ``factor ** d >= nblocks``.  Counting rounds with ceil-division avoids
    ``math.log`` rounding (e.g. ``math.log(125, 5)`` is slightly above 3, whose
    ceiling would add a spurious extra level).

    Raises
    ------
    ValueError
        If ``factor < 2``; such a group size never shrinks the block count.
    """
    if factor < 2:
        raise ValueError(f"split_every values must be at least 2, got {factor!r}")
    levels = 0
    while nblocks > 1:
        nblocks = -(-nblocks // factor)  # ceil(nblocks / factor)
        levels += 1
    return levels


def _tree_reduce(
    x,
    aggregate,
    axis,
    keepdims,
    dtype,
    split_every=None,
    combine=None,
    name=None,
    concatenate=True,
    reduced_meta=None,
):
    """Perform the tree reduction step of a reduction.

    Lower-level helper; users should call ``reduction`` or ``arg_reduction``
    instead.

    Blocks along the reduced axes are combined level by level in groups given
    by ``split_every`` using ``combine`` (falling back to ``aggregate``).  A
    final level applies ``aggregate`` to produce the result.

    Parameters
    ----------
    x
        Input collection.  Its ``numblocks`` determines how many levels are
        built, and it is wrapped in ``PartialReduce`` at each level.
    aggregate
        Callable for the final level, called with ``axis=`` and ``keepdims=``.
    axis
        Axes being reduced.  ``len`` and ``sorted`` are applied to it, and its
        entries are matched against block-dimension indices, so presumably a
        tuple of non-negative ints (TODO confirm with callers).
    keepdims
        Forwarded to ``aggregate`` and to the final ``PartialReduce``.
        Intermediate levels always use ``keepdims=True``.
    dtype
        Forwarded to every ``PartialReduce``.
    split_every
        Int or dict mapping axis -> group size.  Falsy values fall back to
        ``config.get("split_every", 16)``.  Dict entries missing for an axis
        default to 2.  An int ``n`` becomes a per-axis group size equal to the
        integer ``len(axis)``-th root of ``n``, with a minimum of 2.  The
        product over axes therefore does not exceed ``n`` unless the minimum
        applies.
    combine
        Callable for intermediate levels; defaults to ``aggregate``.
    name
        Prefix for generated names (suffixed ``"-partial"`` or
        ``"-aggregate"``); defaults to ``funcname`` of the callable used.
    concatenate
        If True, each group of blocks goes through ``_concatenate2`` along
        ``sorted(axis)`` before the callable is applied.
    reduced_meta
        Forwarded unchanged to every ``PartialReduce``.

    Returns
    -------
    The result of ``new_collection`` wrapping the final ``PartialReduce``.

    Raises
    ------
    ValueError
        If ``split_every`` is neither an int nor a dict, or if a per-axis
        group size other than 1 is smaller than 2.
    """
    # Normalize split_every into an {axis: group_size} dict.
    split_every = split_every or config.get("split_every", 16)
    if isinstance(split_every, dict):
        split_every = {k: split_every.get(k, 2) for k in axis}
    elif isinstance(split_every, Integral):
        # Spread the total fan-in across the reduced axes.  Use an exact
        # integer root; float powers truncate e.g. 64 ** (1/3) to 3.
        per_axis = builtins.max(_floor_int_root(split_every, len(axis) or 1), 2)
        split_every = dict.fromkeys(axis, per_axis)
    else:
        raise ValueError("split_every must be a int or a dict")

    # Number of levels (including the final aggregate) needed so that every
    # reduced block dimension ends as a single block.
    depth = 1
    for i, nblocks in enumerate(x.numblocks):
        if i in split_every and split_every[i] != 1:
            depth = builtins.max(depth, _tree_levels(nblocks, split_every[i]))

    # Intermediate levels keep reduced dims (keepdims=True), presumably so the
    # next level can still concatenate along ``axis``.
    func = partial(combine or aggregate, axis=axis, keepdims=True)
    if concatenate:
        func = compose(func, partial(_concatenate2, axes=sorted(axis)))
    for _ in range(depth - 1):
        x = PartialReduce(
            x,
            func,
            split_every,
            True,
            dtype=dtype,
            name=(name or funcname(combine or aggregate)) + "-partial",
            reduced_meta=reduced_meta,
        )

    # Final level: apply ``aggregate`` with the caller's keepdims.
    func = partial(aggregate, axis=axis, keepdims=keepdims)
    if concatenate:
        func = compose(func, partial(_concatenate2, axes=sorted(axis)))
    return new_collection(
        PartialReduce(
            x,
            func,
            split_every,
            keepdims=keepdims,
            dtype=dtype,
            name=(name or funcname(aggregate)) + "-aggregate",
            reduced_meta=reduced_meta,
        )
    )
no test coverage detected
searching dependent graphs…