A wrapper that adds `keepdims` support to functions that do not accept a `keepdims` argument natively.
(a_callable)
| 14 | |
| 15 | |
| 16 | def keepdims_wrapper(a_callable): |
| 17 | """ |
| 18 | A wrapper for functions that don't provide keepdims to ensure that they do. |
| 19 | """ |
| 20 | |
| 21 | @wraps(a_callable) |
| 22 | def keepdims_wrapped_callable(x, axis=None, keepdims=None, *args, **kwargs): |
| 23 | r = a_callable(x, *args, axis=axis, **kwargs) |
| 24 | |
| 25 | if not keepdims: |
| 26 | return r |
| 27 | |
| 28 | axes = axis |
| 29 | |
| 30 | if axes is None: |
| 31 | axes = range(x.ndim) |
| 32 | |
| 33 | if not isinstance(axes, (Container, Iterable, Sequence)): |
| 34 | axes = [axes] |
| 35 | |
| 36 | r_slice = tuple() |
| 37 | for each_axis in range(x.ndim): |
| 38 | if each_axis in axes: |
| 39 | r_slice += (None,) |
| 40 | else: |
| 41 | r_slice += (slice(None),) |
| 42 | |
| 43 | r = r[r_slice] |
| 44 | |
| 45 | return r |
| 46 | |
| 47 | return keepdims_wrapped_callable |
| 48 | |
| 49 | |
| 50 | # Wrap NumPy functions to ensure they provide keepdims. |
no outgoing calls
searching dependent graphs…