github.com/pydata/xarray

Function least_squares

xarray/compat/dask_array_ops.py, lines 22–74
least_squares(lhs, rhs, rcond=None, skipna=False)


```python
import math  # math.prod collapses the trailing dimensions below

# `nputils` (xarray's NumPy helpers, which provide `_nanpolyfit_1d`) and
# `reshape_blockwise` are module-level names in this file; their import
# lines fall outside the excerpt (lines 22–74) shown here.


def least_squares(lhs, rhs, rcond=None, skipna=False):
    import dask.array as da

    # The trick here is that the core dimension is axis 0.
    # All other dimensions need to be reshaped down to one axis for `lstsq`
    # (which only accepts 2D input),
    # and this needs to be undone after running `lstsq`.
    # The order of values in the reshaped axes is irrelevant.
    # There are big gains to be had by simply reshaping the blocks on a blockwise
    # basis, and then undoing that transform.
    # We use the dedicated `reshape_blockwise` function in dask for this optimization.
    if rhs.ndim > 2:
        out_shape = rhs.shape
        reshape_chunks = rhs.chunks
        rhs = reshape_blockwise(rhs, (rhs.shape[0], math.prod(rhs.shape[1:])))
    else:
        out_shape = None

    # Chunk lhs along axis 0 to match the row chunks of rhs.
    lhs_da = da.from_array(lhs, chunks=(rhs.chunks[0], lhs.shape[1]))
    if skipna:
        added_dim = rhs.ndim == 1
        if added_dim:
            rhs = rhs.reshape(rhs.shape[0], 1)
        # The per-column helper returns lhs.shape[1] coefficients followed by
        # one residual value, hence shape (lhs.shape[1] + 1,).
        results = da.apply_along_axis(
            nputils._nanpolyfit_1d,
            0,
            rhs,
            lhs_da,
            dtype=float,
            shape=(lhs.shape[1] + 1,),
            rcond=rcond,
        )
        coeffs = results[:-1, ...]
        residuals = results[-1, ...]
        if added_dim:
            coeffs = coeffs.reshape(coeffs.shape[0])
            residuals = residuals.reshape(residuals.shape[0])
    else:
        # Residuals here are (1, 1) but should be (K,) as rhs is (N, K)
        # See issue dask/dask#6516
        coeffs, residuals, _, _ = da.linalg.lstsq(lhs_da, rhs)

    # Undo the blockwise reshape so the trailing dimensions match the input.
    if out_shape is not None:
        coeffs = reshape_blockwise(
            coeffs,
            shape=(coeffs.shape[0], *out_shape[1:]),
            chunks=((coeffs.shape[0],), *reshape_chunks[1:]),
        )
        residuals = reshape_blockwise(
            residuals, shape=out_shape[1:], chunks=reshape_chunks[1:]
        )

    return coeffs, residuals
```
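The reshape trick described in the comments can be sketched with NumPy alone, assuming small in-memory arrays rather than dask: collapse every axis after axis 0 into a single column axis, solve once, then restore the trailing shape. The helper `lstsq_nd` is hypothetical, `np.linalg.lstsq` stands in for `da.linalg.lstsq`, and an ordinary C-order reshape replaces `reshape_blockwise`.

```python
import math

import numpy as np


def lstsq_nd(lhs, rhs):
    """Hypothetical helper: least squares with rhs of shape (N, *extra).

    Axis 0 of rhs is the core dimension; all other axes are flattened into
    one column axis so a single 2D solve handles every column at once.
    """
    extra = rhs.shape[1:]
    rhs2d = rhs.reshape(rhs.shape[0], math.prod(extra))  # (N, K)
    coeffs, residuals, _, _ = np.linalg.lstsq(lhs, rhs2d, rcond=None)
    # residuals is (K,) only for a full-rank lhs with N > M; otherwise NumPy
    # returns an empty array and the reshape below would fail.
    return coeffs.reshape(lhs.shape[1], *extra), residuals.reshape(extra)


rng = np.random.default_rng(0)
lhs = rng.normal(size=(20, 3))             # N=20 observations, M=3 unknowns
true = rng.normal(size=(3, 4, 5))          # one coefficient vector per (4, 5) cell
rhs = np.einsum("nm,mij->nij", lhs, true)  # exact data, no noise
coeffs, residuals = lstsq_nd(lhs, rhs)
print(coeffs.shape, residuals.shape)       # (3, 4, 5) (4, 5)
```

Because every column is an independent problem, the single 2D solve gives the same coefficients as looping over the (4, 5) grid, while calling the solver only once.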
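With `skipna=True`, each column of `rhs` is solved on its own, which lets every column drop only its own NaN rows; the per-column helper returns the coefficients with one residual value appended, which is why the caller slices `results[:-1, ...]` and `results[-1, ...]`. The sketch below imitates that contract with NumPy's `np.apply_along_axis`. The helper `nanlstsq_1d` is hypothetical and is not xarray's `_nanpolyfit_1d`; its NaN handling is an assumption modeled on how the caller uses the output.

```python
import numpy as np


def nanlstsq_1d(y, x, rcond=None):
    """Hypothetical per-column solver: drop NaN rows of y, fit, append residual.

    Returns an array of length x.shape[1] + 1: the coefficients followed by the
    sum of squared residuals (NaN if it is unavailable or y is all NaN).
    """
    out = np.full(x.shape[1] + 1, np.nan)
    keep = ~np.isnan(y)
    if keep.any():
        coef, resid, _, _ = np.linalg.lstsq(x[keep], y[keep], rcond=rcond)
        out[:-1] = coef
        if resid.size:  # empty when rank-deficient or not overdetermined
            out[-1] = resid[0]
    return out


x = np.column_stack([np.ones(6), np.arange(6.0)])  # design matrix (6, 2)
rhs = np.column_stack([1 + 2 * np.arange(6.0), 3 - np.arange(6.0)])  # (6, 2)
rhs[1, 0] = np.nan  # a gap in the first column only
rhs[4, 1] = np.nan  # a different gap in the second column

# func1d receives each 1D column of rhs (axis=0); x and rcond are forwarded.
results = np.apply_along_axis(nanlstsq_1d, 0, rhs, x, rcond=None)
coeffs, residuals = results[:-1, ...], results[-1, ...]
print(coeffs)
```

Each column keeps five of its six rows, and the fitted intercepts and slopes are recovered independently of the gap in the other column.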

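The comment in the non-`skipna` branch notes that dask's `lstsq` returns residuals of shape (1, 1) where one value per column, shape (K,), is expected (dask/dask#6516). Per-column sums of squared residuals can always be recomputed from the coefficients. The NumPy sketch below, with the hypothetical helper `column_residuals`, checks that the recomputed values match what `np.linalg.lstsq` reports for a full-rank, overdetermined system; it does not call dask.

```python
import numpy as np


def column_residuals(lhs, rhs, coeffs):
    """Hypothetical helper: sum of squared residuals for each column of rhs."""
    return ((rhs - lhs @ coeffs) ** 2).sum(axis=0)  # shape (K,)


rng = np.random.default_rng(1)
lhs = rng.normal(size=(30, 4))  # (N, M), full rank with N > M
rhs = rng.normal(size=(30, 6))  # (N, K): six independent right-hand sides
coeffs, resid_np, rank, _ = np.linalg.lstsq(lhs, rhs, rcond=None)
resid = column_residuals(lhs, rhs, coeffs)
print(resid.shape)  # (6,)
```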
Callers

nothing calls this directly

Calls (3)

reshape_blockwise · Function · 0.90
prod · Method · 0.45
from_array · Method · 0.45
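Dask's `reshape_blockwise` reshapes each chunk independently and concatenates the results, so the flattened axis does not follow NumPy's global C order. The code comment says this is safe because the order of values in the reshaped axes is irrelevant: each column of the 2D `rhs` is an independent least-squares problem, so any permutation of columns applied before the solve is undone exactly by inverting it afterward. The NumPy sketch below checks that property with a random permutation standing in for the blockwise ordering; it does not use dask or `reshape_blockwise` itself.

```python
import numpy as np

rng = np.random.default_rng(2)
lhs = rng.normal(size=(25, 3))  # (N, M)
rhs = rng.normal(size=(25, 8))  # (N, K) after flattening the trailing axes

perm = rng.permutation(rhs.shape[1])  # scramble the column order
inverse = np.argsort(perm)            # permutation that undoes `perm`

ref, _, _, _ = np.linalg.lstsq(lhs, rhs, rcond=None)
scrambled, _, _, _ = np.linalg.lstsq(lhs, rhs[:, perm], rcond=None)
restored = scrambled[:, inverse]      # undo the scramble on the solution

print(np.allclose(ref, restored))  # True
```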

Tested by

no test coverage detected
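A test for this kind of solver could compare against a known reference. The sketch below assumes a Vandermonde design matrix for `lhs` (the kind used in a polynomial fit) and noisy 3D data, flattens the trailing axes as `least_squares` does, and checks the result against `np.polyfit` column by column. It uses NumPy instead of dask, and the test name is hypothetical.

```python
import math

import numpy as np


def test_lstsq_matches_polyfit():
    """Hypothetical test: a Vandermonde lhs reproduces np.polyfit per column."""
    rng = np.random.default_rng(3)
    x = np.linspace(0.0, 1.0, 15)
    lhs = np.vander(x, 3)               # columns x**2, x, 1 (decreasing powers)
    rhs = rng.normal(size=(15, 2, 3))   # noisy data; trailing dims (2, 3)

    flat = rhs.reshape(15, math.prod(rhs.shape[1:]))  # (15, 6)
    coeffs, _, _, _ = np.linalg.lstsq(lhs, flat, rcond=None)
    coeffs = coeffs.reshape(3, *rhs.shape[1:])        # back to (3, 2, 3)

    # np.polyfit accepts a 2D y and orders coefficients like np.vander.
    expected = np.polyfit(x, flat, 2).reshape(3, 2, 3)
    assert coeffs.shape == (3, 2, 3)
    assert np.allclose(coeffs, expected)


test_lstsq_matches_polyfit()
```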
