
Function _average

dask/array/routines.py:2507–2557 (github.com/dask/dask)
_average(a, axis=None, weights=None, returned=False, is_masked=False, keepdims=False)


```python
def _average(
    a, axis=None, weights=None, returned=False, is_masked=False, keepdims=False
):
    # This was minimally modified from numpy.average
    # See numpy license at https://github.com/numpy/numpy/blob/master/LICENSE.txt
    # or NUMPY_LICENSE.txt within this directory
    # Wrapper used by da.average or da.ma.average.
    a = asanyarray(a)

    if weights is None:
        avg = a.mean(axis, keepdims=keepdims)
        scl = avg.dtype.type(a.size / avg.size)
    else:
        wgt = asanyarray(weights)

        if issubclass(a.dtype.type, (np.integer, np.bool_)):
            result_dtype = result_type(a.dtype, wgt.dtype, "f8")
        else:
            result_dtype = result_type(a.dtype, wgt.dtype)

        # Sanity checks
        if a.shape != wgt.shape:
            if axis is None:
                raise TypeError(
                    "Axis must be specified when shapes of a and weights differ."
                )
            if wgt.ndim != 1:
                raise TypeError(
                    "1D weights expected when shapes of a and weights differ."
                )
            if wgt.shape[0] != a.shape[axis]:
                raise ValueError(
                    "Length of weights not compatible with specified axis."
                )

            # setup wgt to broadcast along axis
            wgt = broadcast_to(wgt, (a.ndim - 1) * (1,) + wgt.shape)
            wgt = wgt.swapaxes(-1, axis)
        if is_masked:
            from dask.array.ma import getmaskarray

            wgt = wgt * (~getmaskarray(a))
        scl = wgt.sum(axis=axis, dtype=result_dtype, keepdims=keepdims)
        avg = multiply(a, wgt, dtype=result_dtype).sum(axis, keepdims=keepdims) / scl

    if returned:
        if scl.shape != avg.shape:
            scl = broadcast_to(scl, avg.shape).copy()
        return avg, scl
    else:
        return avg
```
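The broadcasting step in the weighted branch is easiest to follow on a small array. The sketch below replays the weighted and `is_masked` paths with plain NumPy, under the assumption that `np.broadcast_to`, `swapaxes`, and `np.ma.getmaskarray` behave like the dask functions `_average` calls; the array `a`, the `weights`, and the mask are hypothetical example data. It is not dask's own code path, but because `_average` was adapted from `numpy.average`, its results can be checked against `np.average` and `np.ma.average`.

```python
import numpy as np

# Hypothetical example data: average a 3x4 array along axis 0 with one
# weight per row. NumPy stands in for dask.array here (assumption).
a = np.arange(12.0).reshape(3, 4)
weights = np.array([1.0, 2.0, 3.0])
axis = 0

# Prepend (a.ndim - 1) singleton axes: shape (3,) -> (1, 3).
wgt = np.broadcast_to(weights, (a.ndim - 1) * (1,) + weights.shape)
# Swap the weight axis onto `axis`: (1, 3) -> (3, 1), which broadcasts against a.
wgt = wgt.swapaxes(-1, axis)

# Sum of weights (scl) and the weighted sum divided by it (avg).
scl = wgt.sum(axis=axis)                   # shape (1,)
avg = np.multiply(a, wgt).sum(axis) / scl  # shape (4,)

# The returned=True branch broadcasts scl up to avg's shape.
if scl.shape != avg.shape:
    scl = np.broadcast_to(scl, avg.shape).copy()  # shape (4,)

# is_masked=True path: zero the weight of every masked element, so it drops
# out of both the weighted sum and the sum of weights.
mask = np.zeros(a.shape, dtype=bool)
mask[1, 0] = True
m = np.ma.masked_array(a, mask=mask)
wgt_m = wgt * ~np.ma.getmaskarray(m)  # shape (3, 4)
scl_m = wgt_m.sum(axis=axis)
avg_m = np.multiply(m, wgt_m).sum(axis) / scl_m
```

Because the weights never vary along axis 1, the summed weights come out with shape `(1,)` rather than `(4,)`; that mismatch is what the `returned` branch of `_average` corrects with `broadcast_to(...).copy()`.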

Callers (2)

average (function) · 0.90
average (function) · 0.85

Calls (8)

asanyarray (function) · 0.90
broadcast_to (function) · 0.90
getmaskarray (function) · 0.90
result_type (function) · 0.85
swapaxes (method) · 0.80
mean (method) · 0.45
sum (method) · 0.45
copy (method) · 0.45

Tested by

no test coverage detected
