Perform the tree reduction step of a reduction. This is a lower-level helper; users should call ``reduction`` or ``arg_reduction`` instead.
(
x,
aggregate,
axis,
keepdims,
dtype,
split_every=None,
combine=None,
name=None,
concatenate=True,
reduced_meta=None,
)
| 213 | |
| 214 | |
def _tree_reduce(
    x,
    aggregate,
    axis,
    keepdims,
    dtype,
    split_every=None,
    combine=None,
    name=None,
    concatenate=True,
    reduced_meta=None,
):
    """Perform the tree reduction step of a reduction.

    Lower-level helper: users should call ``reduction`` or ``arg_reduction``
    rather than calling this directly.

    Blocks along the reduced axes are merged in a tree: ``depth - 1`` rounds
    of ``partial_reduce`` applying ``combine`` (or ``aggregate`` when
    ``combine`` is None) with ``keepdims=True``, followed by one final
    ``partial_reduce`` applying ``aggregate`` with the caller's ``keepdims``.

    Parameters
    ----------
    x
        Array-like with a ``numblocks`` attribute (one block count per
        dimension); it is passed through ``partial_reduce`` each round.
    aggregate : callable
        Applied in the final round as ``aggregate(..., axis=axis,
        keepdims=keepdims)``; also used for intermediate rounds when
        ``combine`` is None.
    axis : tuple of int
        Axes being reduced.
    keepdims : bool
        Forwarded to ``aggregate`` and to the final ``partial_reduce``.
    dtype
        Forwarded to every ``partial_reduce`` call.
    split_every : int or dict, optional
        Fan-in of each tree node. When falsy, ``config.get("split_every", 4)``
        is used. An int is spread over the reduced axes as the integer
        ``len(axis)``-th root of the value (at least 2) per axis. A dict maps
        axis -> fan-in; reduced axes missing from it default to 2, and keys
        that are not reduced axes are dropped.
    combine : callable, optional
        Used instead of ``aggregate`` for the intermediate rounds.
    name : str, optional
        Prefix for the generated names; defaults to ``funcname`` of the
        function applied in that round.
    concatenate : bool
        If true, each group of blocks is first joined with ``_concatenate2``
        along the (sorted) reduced axes before the function is applied.
    reduced_meta
        Forwarded to every ``partial_reduce`` call.

    Returns
    -------
    The result of the final ``partial_reduce`` call.

    Raises
    ------
    ValueError
        If ``split_every`` is neither an integer nor a dict.
    """
    # Normalize split_every into an {axis: fan_in} mapping.
    split_every = split_every or config.get("split_every", 4)
    if isinstance(split_every, dict):
        split_every = {k: split_every.get(k, 2) for k in axis}
    elif isinstance(split_every, Integral):
        ndim = len(axis) or 1
        n = int(split_every ** (1 / ndim))
        # The float root can land just below an exact integer root
        # (64 ** (1/3) == 3.9999999999999996), which int() would truncate.
        # Settle on the true integer root: the largest n with
        # n ** ndim <= split_every.
        while (n + 1) ** ndim <= split_every:
            n += 1
        while n > 1 and n**ndim > split_every:
            n -= 1
        n = builtins.max(n, 2)
        split_every = dict.fromkeys(axis, n)
    else:
        raise ValueError("split_every must be a int or a dict")

    # Reduce across intermediates. The depth is the number of rounds needed
    # so that fan_in ** depth covers the block count on every split axis
    # (at least 1, for the final aggregate).
    depth = 1
    for i, n in enumerate(x.numblocks):
        if i in split_every and split_every[i] != 1:
            base = split_every[i]
            d = math.ceil(math.log(n, base))
            if isinstance(base, Integral) and isinstance(n, Integral):
                # math.log is inexact: it overshoots exact powers
                # (math.log(125, 5) > 3, adding a needless round) and can
                # undershoot for large n. Correct to the smallest d with
                # base ** d >= n using exact integer arithmetic. Reaching
                # here implies base >= 2 and n >= 1, because math.log raised
                # otherwise.
                while d > 0 and base ** (d - 1) >= n:
                    d -= 1
                while base**d < n:
                    d += 1
            depth = int(builtins.max(depth, d))

    func = partial(combine or aggregate, axis=axis, keepdims=True)
    if concatenate:
        func = compose(func, partial(_concatenate2, axes=sorted(axis)))
    for _ in range(depth - 1):
        x = partial_reduce(
            func,
            x,
            split_every,
            True,
            dtype=dtype,
            name=(name or funcname(combine or aggregate)) + "-partial",
            reduced_meta=reduced_meta,
        )

    # Final round: always `aggregate`, honoring the caller's keepdims.
    func = partial(aggregate, axis=axis, keepdims=keepdims)
    if concatenate:
        func = compose(func, partial(_concatenate2, axes=sorted(axis)))
    return partial_reduce(
        func,
        x,
        split_every,
        keepdims=keepdims,
        dtype=dtype,
        name=(name or funcname(aggregate)) + "-aggregate",
        reduced_meta=reduced_meta,
    )
| 271 | |
| 272 |
no test coverage detected
searching dependent graphs…