Repository: github.com/scikit-learn/scikit-learn

Function _average

sklearn/utils/_array_api.py, lines 789–852

Partial port of np.average that supports the Array API. It makes a best effort to mimic the return-dtype rule described at https://numpy.org/doc/stable/reference/generated/numpy.average.html, but only for the common cases needed in scikit-learn.
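The return-dtype rule being mimicked can be observed directly with plain numpy.average. This illustration does not call _average itself: integer inputs are promoted to float64, while float32 inputs without weights keep float32.

```python
import numpy as np

# Integer input: np.average promotes the result to float64.
ints = np.array([1, 2, 3])
print(np.average(ints).dtype)  # float64

# float32 input without weights: the result stays float32.
f32 = np.array([1.0, 2.0, 3.0], dtype=np.float32)
print(np.average(f32).dtype)  # float32

# Integer input with integer weights: still promoted to float64.
# (1*1 + 2*1 + 3*2) / (1 + 1 + 2) = 9 / 4
weighted = np.average(ints, weights=np.array([1, 1, 2]))
print(weighted, weighted.dtype)  # 2.25 float64
```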

(a, axis=None, weights=None, normalize=True, xp=None)
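In this signature, normalize=True requests a weighted mean and normalize=False requests the un-normalized weighted sum. For NumPy inputs the listing delegates to numpy.average and, when axis is None, to numpy.dot. The relationship can be illustrated with plain NumPy; this sketch does not call _average itself.

```python
import numpy as np

a = np.array([1.0, 2.0, 3.0])
w = np.array([1.0, 1.0, 2.0])

# normalize=True corresponds to the weighted mean: sum(a * w) / sum(w).
mean_ = np.average(a, weights=w)

# normalize=False with axis=None corresponds to the weighted sum;
# for 1-D inputs np.dot(a, w) equals sum(a * w).
sum_ = np.dot(a, w)

print(mean_, sum_)  # 2.25 9.0
```

Note that np.dot coincides with the elementwise weighted sum only for 1-D inputs; for 2-D arrays it computes a matrix product instead.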


```python
def _average(a, axis=None, weights=None, normalize=True, xp=None):
    """Partial port of np.average to support the Array API.

    It does a best effort at mimicking the return dtype rule described at
    https://numpy.org/doc/stable/reference/generated/numpy.average.html but
    only for the common cases needed in scikit-learn.
    """
    xp, _, device_ = get_namespace_and_device(a, weights, xp=xp)

    if _is_numpy_namespace(xp):
        if normalize:
            return xp.asarray(numpy.average(a, axis=axis, weights=weights))
        elif axis is None and weights is not None:
            return xp.asarray(numpy.dot(a, weights))

    a = xp.asarray(a, device=device_)
    if weights is not None:
        weights = xp.asarray(weights, device=device_)

    if weights is not None and a.shape != weights.shape:
        if axis is None:
            raise TypeError(
                f"Axis must be specified when the shape of a {tuple(a.shape)} and "
                f"weights {tuple(weights.shape)} differ."
            )

        if tuple(weights.shape) != (a.shape[axis],):
            raise ValueError(
                f"Shape of weights weights.shape={tuple(weights.shape)} must be "
                f"consistent with a.shape={tuple(a.shape)} and {axis=}."
            )

        # If weights are 1D, add singleton dimensions for broadcasting
        shape = [1] * a.ndim
        shape[axis] = a.shape[axis]
        weights = xp.reshape(weights, tuple(shape))

    if xp.isdtype(a.dtype, "complex floating"):
        raise NotImplementedError(
            "Complex floating point values are not supported by average."
        )
    if weights is not None and xp.isdtype(weights.dtype, "complex floating"):
        raise NotImplementedError(
            "Complex floating point values are not supported by average."
        )

    output_dtype = _find_matching_floating_dtype(a, weights, xp=xp)
    a = xp.astype(a, output_dtype)

    if weights is None:
        return (xp.mean if normalize else xp.sum)(a, axis=axis)

    weights = xp.astype(weights, output_dtype)

    sum_ = xp.sum(xp.multiply(a, weights), axis=axis)

    if not normalize:
        return sum_

    # (listing truncated here; source lines 847–852 are not shown)
```
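The Array API fallback path in the listing can be mirrored with NumPy standing in for xp. The sketch below is hypothetical: average_sketch and _float_dtype_sketch are illustrative names, not scikit-learn APIs. The dtype helper is an assumed stand-in for _find_matching_floating_dtype, whose exact rule is not shown here. The final division by the weight total is also an assumption, because the listing stops before the normalize=True weighted branch; it follows np.average, which raises ZeroDivisionError when the weights along the axis sum to zero.

```python
import numpy as np


def _float_dtype_sketch(*arrays):
    # Assumption: keep the widest floating dtype among the inputs,
    # otherwise fall back to float64 (stand-in for _find_matching_floating_dtype).
    floats = [
        x.dtype for x in arrays
        if x is not None and np.issubdtype(x.dtype, np.floating)
    ]
    return np.result_type(*floats) if floats else np.dtype(np.float64)


def average_sketch(a, axis=None, weights=None, normalize=True):
    """Hypothetical NumPy-only mirror of the Array API fallback path."""
    a = np.asarray(a)
    if weights is not None:
        weights = np.asarray(weights)

    if weights is not None and a.shape != weights.shape:
        if axis is None:
            raise TypeError("axis must be given when a and weights differ in shape")
        if weights.shape != (a.shape[axis],):
            raise ValueError("1-D weights must have length a.shape[axis]")
        # Reshape 1-D weights to (1, ..., n, ..., 1) so they broadcast along axis.
        shape = [1] * a.ndim
        shape[axis] = a.shape[axis]
        weights = np.reshape(weights, tuple(shape))

    if np.iscomplexobj(a) or (weights is not None and np.iscomplexobj(weights)):
        raise NotImplementedError("complex inputs are not supported")

    dtype = _float_dtype_sketch(a, weights)
    a = a.astype(dtype)
    if weights is None:
        return a.mean(axis=axis) if normalize else a.sum(axis=axis)

    weights = weights.astype(dtype)
    sum_ = np.sum(np.multiply(a, weights), axis=axis)
    if not normalize:
        return sum_

    # Assumption: divide by the weight total, as np.average does.
    scale = np.sum(weights, axis=axis)
    if np.any(scale == 0.0):
        raise ZeroDivisionError("Weights sum to zero, can't be normalized")
    return sum_ / scale


X = np.array([[1.0, 2.0], [3.0, 4.0]])
w = np.array([1.0, 3.0])
print(average_sketch(X, axis=0, weights=w))  # [2.5 3.5]
print(average_sketch(X, axis=1, weights=w))  # [1.75 3.75]
print(average_sketch(X, axis=0, weights=w, normalize=False))  # [10. 14.]
```

Reshaping the 1-D weights to shape (2, 1) for axis=0, or (1, 2) for axis=1, is what lets a single elementwise multiply followed by a sum along axis produce the weighted reduction.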

Callers (15 indexed; 9 shown)

_update_mean_variance (method) · 0.90
_nanaverage (function) · 0.90
test_average (function) · 0.90
_average_binary_score (function) · 0.90
mean_absolute_error (function) · 0.90
mean_pinball_loss (function) · 0.90
mean_squared_error (function) · 0.90
root_mean_squared_error (function) · 0.90
median_absolute_error (function) · 0.90
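Several of the callers listed are regression metrics that reduce per-sample errors over axis 0, optionally with sample weights. A minimal sketch of that pattern follows, assuming np.average as a stand-in for _average; weighted_mse_sketch is a hypothetical name and is not the callers' actual code.

```python
import numpy as np


def weighted_mse_sketch(y_true, y_pred, sample_weight=None):
    # Hypothetical helper: per-output squared errors averaged over samples
    # (axis=0), with 1-D sample weights broadcast along that axis.
    errors = (np.asarray(y_true, dtype=float) - np.asarray(y_pred, dtype=float)) ** 2
    return np.average(errors, axis=0, weights=sample_weight)


y_true = np.array([[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]])
y_pred = np.array([[1.0, 1.0], [2.0, 1.0], [4.0, 6.0]])
sw = np.array([1.0, 2.0, 1.0])

# Squared errors per row: [1, 0], [0, 4], [0, 1].
print(weighted_mse_sketch(y_true, y_pred, sw))  # [0.25 2.25]
```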

Calls (4 indexed; 3 shown)

get_namespace_and_device (function) · 0.85
_is_numpy_namespace (function) · 0.85
multiply (method) · 0.80
