Partial port of np.average to support the Array API. It makes a best effort to mimic the return dtype rule described at https://numpy.org/doc/stable/reference/generated/numpy.average.html, but only for the common cases needed in scikit-learn.
(a, axis=None, weights=None, normalize=True, xp=None)
| 787 | |
| 788 | |
| 789 | def _average(a, axis=None, weights=None, normalize=True, xp=None): |
| 790 | """Partial port of np.average to support the Array API. |
| 791 | |
| 792 | It does a best effort at mimicking the return dtype rule described at |
| 793 | https://numpy.org/doc/stable/reference/generated/numpy.average.html but |
| 794 | only for the common cases needed in scikit-learn. |
| 795 | """ |
| 796 | xp, _, device_ = get_namespace_and_device(a, weights, xp=xp) |
| 797 | |
| 798 | if _is_numpy_namespace(xp): |
| 799 | if normalize: |
| 800 | return xp.asarray(numpy.average(a, axis=axis, weights=weights)) |
| 801 | elif axis is None and weights is not None: |
| 802 | return xp.asarray(numpy.dot(a, weights)) |
| 803 | |
| 804 | a = xp.asarray(a, device=device_) |
| 805 | if weights is not None: |
| 806 | weights = xp.asarray(weights, device=device_) |
| 807 | |
| 808 | if weights is not None and a.shape != weights.shape: |
| 809 | if axis is None: |
| 810 | raise TypeError( |
| 811 | f"Axis must be specified when the shape of a {tuple(a.shape)} and " |
| 812 | f"weights {tuple(weights.shape)} differ." |
| 813 | ) |
| 814 | |
| 815 | if tuple(weights.shape) != (a.shape[axis],): |
| 816 | raise ValueError( |
| 817 | f"Shape of weights weights.shape={tuple(weights.shape)} must be " |
| 818 | f"consistent with a.shape={tuple(a.shape)} and {axis=}." |
| 819 | ) |
| 820 | |
| 821 | # If weights are 1D, add singleton dimensions for broadcasting |
| 822 | shape = [1] * a.ndim |
| 823 | shape[axis] = a.shape[axis] |
| 824 | weights = xp.reshape(weights, tuple(shape)) |
| 825 | |
| 826 | if xp.isdtype(a.dtype, "complex floating"): |
| 827 | raise NotImplementedError( |
| 828 | "Complex floating point values are not supported by average." |
| 829 | ) |
| 830 | if weights is not None and xp.isdtype(weights.dtype, "complex floating"): |
| 831 | raise NotImplementedError( |
| 832 | "Complex floating point values are not supported by average." |
| 833 | ) |
| 834 | |
| 835 | output_dtype = _find_matching_floating_dtype(a, weights, xp=xp) |
| 836 | a = xp.astype(a, output_dtype) |
| 837 | |
| 838 | if weights is None: |
| 839 | return (xp.mean if normalize else xp.sum)(a, axis=axis) |
| 840 | |
| 841 | weights = xp.astype(weights, output_dtype) |
| 842 | |
| 843 | sum_ = xp.sum(xp.multiply(a, weights), axis=axis) |
| 844 | |
| 845 | if not normalize: |
| 846 | return sum_ |
searching dependent graphs…