(
a, axis=None, weights=None, returned=False, is_masked=False, keepdims=False
)
| 2505 | |
| 2506 | |
def _average(
    a, axis=None, weights=None, returned=False, is_masked=False, keepdims=False
):
    """Compute the (optionally weighted) average of ``a`` along ``axis``.

    Shared implementation behind ``da.average`` and ``da.ma.average``.

    Parameters
    ----------
    a : array_like
        Input data; converted with ``asanyarray``.
    axis : optional
        Axis to reduce over; passed through to the reductions. ``None``
        averages over all elements. Must be an int when ``weights`` is 1-D
        and differs in shape from ``a`` (it is used to index ``a.shape``).
    weights : array_like, optional
        Either the same shape as ``a``, or 1-D with length ``a.shape[axis]``
        (then broadcast along ``axis``). ``None`` gives an unweighted mean.
    returned : bool, optional
        If True, return ``(avg, scl)`` where ``scl`` is the sum of weights.
    is_masked : bool, optional
        If True, ``a`` is treated as masked: masked elements get zero weight,
        and without ``weights`` the returned sum of weights is the number of
        unmasked elements (as in ``numpy.ma.average``).
    keepdims : bool, optional
        Keep reduced axes as length-1 dimensions.

    Returns
    -------
    avg or (avg, scl)
        ``scl`` is broadcast to ``avg.shape`` when the shapes differ.
        Weighted integer/bool input is promoted to at least float64.

    Raises
    ------
    TypeError
        If ``a`` and ``weights`` differ in shape and ``axis`` is None, or
        ``weights`` is not 1-D.
    ValueError
        If 1-D ``weights`` does not have length ``a.shape[axis]``.

    Notes
    -----
    Unlike ``numpy.average``, no check is made here for weights summing to
    zero.
    """
    # This was minimally modified from numpy.average
    # See numpy license at https://github.com/numpy/numpy/blob/master/LICENSE.txt
    # or NUMPY_LICENSE.txt within this directory
    a = asanyarray(a)

    if is_masked:
        # Function-scope import kept from the original; presumably avoids an
        # import cycle with dask.array.ma -- TODO confirm.
        from dask.array.ma import getmaskarray

    if weights is None:
        avg = a.mean(axis, keepdims=keepdims)
        if is_masked:
            # The sum of weights of an unweighted masked average is the count
            # of unmasked elements along `axis`, not the total element count.
            scl = (
                (~getmaskarray(a))
                .sum(axis=axis, keepdims=keepdims)
                .astype(avg.dtype)
            )
        else:
            scl = avg.dtype.type(a.size / avg.size)
    else:
        wgt = asanyarray(weights)

        if issubclass(a.dtype.type, (np.integer, np.bool_)):
            # Never accumulate integer/bool data in integer arithmetic.
            result_dtype = result_type(a.dtype, wgt.dtype, "f8")
        else:
            result_dtype = result_type(a.dtype, wgt.dtype)

        # Sanity checks
        if a.shape != wgt.shape:
            if axis is None:
                raise TypeError(
                    "Axis must be specified when shapes of a and weights differ."
                )
            if wgt.ndim != 1:
                raise TypeError(
                    "1D weights expected when shapes of a and weights differ."
                )
            if wgt.shape[0] != a.shape[axis]:
                raise ValueError(
                    "Length of weights not compatible with specified axis."
                )

            # Reshape 1-D weights to (1, ..., 1, n), then move the length-n
            # dimension onto `axis` so they broadcast against `a`.
            wgt = broadcast_to(wgt, (a.ndim - 1) * (1,) + wgt.shape)
            wgt = wgt.swapaxes(-1, axis)
        if is_masked:
            # Zero the weight of masked elements so they drop out of both sums.
            wgt = wgt * (~getmaskarray(a))
        scl = wgt.sum(axis=axis, dtype=result_dtype, keepdims=keepdims)
        avg = multiply(a, wgt, dtype=result_dtype).sum(axis, keepdims=keepdims) / scl

    if returned:
        if scl.shape != avg.shape:
            # Expand scl (e.g. a scalar count) to avg's shape, as
            # numpy.average does.
            scl = broadcast_to(scl, avg.shape).copy()
        return avg, scl
    return avg
| 2558 | |
| 2559 | |
| 2560 | @derived_from(np) |
no test coverage detected
searching dependent graphs…