hub / github.com/pydata/xarray / _cov_corr

Function _cov_corr

xarray/computation/computation.py:255–313

Internal helper shared by xr.cov() and xr.corr(), so that the input arrays are sanitized only once and the code is not duplicated.

(
    da_a: T_DataArray,
    da_b: T_DataArray,
    weights: T_DataArray | None = None,
    dim: Dims = None,
    ddof: int = 0,
    method: Literal["cov", "corr"] | None = None,
)
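Because _cov_corr is internal, it is normally reached through xr.cov() and xr.corr(). The following is an illustrative sketch rather than part of the xarray source: the arrays a and b are made-up sample data, and the calls assume the public functions accept the dim keyword (and ddof for xr.cov).

```python
import numpy as np
import xarray as xr

# Made-up sample data; the last pair contains a NaN and is ignored pairwise.
a = xr.DataArray([1.0, 2.0, 3.0, 4.0], dims="time")
b = xr.DataArray([2.0, 4.0, 6.0, np.nan], dims="time")

# Population covariance over the 3 valid pairs (ddof=0, no rescaling).
cov_pop = xr.cov(a, b, dim="time", ddof=0)

# Sample covariance: the same value rescaled by N / (N - ddof) with N = 3.
cov_sample = xr.cov(a, b, dim="time", ddof=1)

# Pearson correlation; b is an exact multiple of a on the valid pairs.
r = xr.corr(a, b, dim="time")

print(float(cov_pop), float(cov_sample), float(r))
```

Only positions where both inputs are non-null contribute, which is why the fourth element of a has no effect on any of the three results.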

Source from the content-addressed store, hash-verified

def _cov_corr(
    da_a: T_DataArray,
    da_b: T_DataArray,
    weights: T_DataArray | None = None,
    dim: Dims = None,
    ddof: int = 0,
    method: Literal["cov", "corr"] | None = None,
) -> T_DataArray:
    """
    Internal method for xr.cov() and xr.corr() so only have to
    sanitize the input arrays once and we don't repeat code.
    """
    # 1. Broadcast the two arrays
    da_a, da_b = align(da_a, da_b, join="inner", copy=False)

    # 2. Ignore the nans
    valid_values = da_a.notnull() & da_b.notnull()
    da_a = da_a.where(valid_values)
    da_b = da_b.where(valid_values)

    # 3. Detrend along the given dim
    if weights is not None:
        demeaned_da_a = da_a - da_a.weighted(weights).mean(dim=dim)
        demeaned_da_b = da_b - da_b.weighted(weights).mean(dim=dim)
    else:
        demeaned_da_a = da_a - da_a.mean(dim=dim)
        demeaned_da_b = da_b - da_b.mean(dim=dim)

    # 4. Compute covariance along the given dim
    # N.B. `skipna=True` is required or auto-covariance is computed incorrectly. E.g.
    # Try xr.cov(da,da) for da = xr.DataArray([[1, 2], [1, np.nan]], dims=["x", "time"])
    if weights is not None:
        cov = (
            (demeaned_da_a.conj() * demeaned_da_b)
            .weighted(weights)
            .mean(dim=dim, skipna=True)
        )
    else:
        cov = (demeaned_da_a.conj() * demeaned_da_b).mean(dim=dim, skipna=True)

    if method == "cov":
        # Adjust covariance for degrees of freedom
        valid_count = valid_values.sum(dim)
        adjust = valid_count / (valid_count - ddof)
        # I think the cast is required because of `T_DataArray` + `T_Xarray` (would be
        # the same with `T_DatasetOrArray`)
        # https://github.com/pydata/xarray/pull/8384#issuecomment-1784228026
        return cast(T_DataArray, cov * adjust)

    else:
        # Compute std and corr
        if weights is not None:
            da_a_std = da_a.weighted(weights).std(dim=dim)
            da_b_std = da_b.weighted(weights).std(dim=dim)
        else:
            da_a_std = da_a.std(dim=dim)
            da_b_std = da_b.std(dim=dim)
        corr = cov / (da_a_std * da_b_std)
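To make steps 2 to 4 and the ddof adjustment easier to follow, here is a minimal NumPy sketch of the unweighted, one-dimensional case. The function cov_corr_1d is hypothetical and is not part of xarray; it skips alignment, weights, and multi-dimensional reduction, and only mirrors the arithmetic of the listing above.

```python
import numpy as np

def cov_corr_1d(a, b, ddof=0, method="cov"):
    """Hypothetical 1-D, unweighted analogue of the listing (not xarray code)."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)

    # Step 2: keep a position only if both inputs are non-NaN there.
    valid = ~np.isnan(a) & ~np.isnan(b)
    a = np.where(valid, a, np.nan)
    b = np.where(valid, b, np.nan)

    # Step 3: subtract the NaN-skipping mean from each input.
    dev_a = a - np.nanmean(a)
    dev_b = b - np.nanmean(b)

    # Step 4: NaN-skipping mean of the conjugated product.
    cov = np.nanmean(np.conj(dev_a) * dev_b)

    if method == "cov":
        # Rescale from a 1/N mean to a 1/(N - ddof) estimate.
        n = valid.sum()
        return cov * n / (n - ddof)

    # Correlation: divide by the product of the (ddof=0) standard deviations.
    return cov / (np.nanstd(a) * np.nanstd(b))

x = [1.0, 2.0, 3.0, 4.0, np.nan]
y = [2.0, 1.0, 4.0, np.nan, 5.0]
print(cov_corr_1d(x, y, ddof=1), cov_corr_1d(x, y, method="corr"))
```

On the three positions where both x and y are valid, the results should agree with np.cov(..., ddof=1) and np.corrcoef applied to those positions alone.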

Callers (2)

cov · Function · 0.85
corr · Function · 0.85

Calls (8)

align · Function · 0.90
notnull · Method · 0.45
where · Method · 0.45
mean · Method · 0.45
weighted · Method · 0.45
conj · Method · 0.45
sum · Method · 0.45
std · Method · 0.45

Tested by

no test coverage detected
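As a starting point for coverage, the example named in the step 4 comment of the source (auto-covariance of an array containing a NaN) can be turned into a small check. This is a sketch only, not an existing xarray test: it goes through the public xr.cov() wrapper, and it assumes ddof=0 so that the single-sample row does not divide by zero.

```python
import numpy as np
import xarray as xr

# The array from the source comment; row x=1 has one valid sample and one NaN.
da = xr.DataArray([[1.0, 2.0], [1.0, np.nan]], dims=["x", "time"])

# Auto-covariance along "time" should equal the NaN-skipping variance.
auto_cov = xr.cov(da, da, dim="time", ddof=0)
variance = da.var(dim="time", ddof=0)

np.testing.assert_allclose(auto_cov.values, variance.values)
```

If the final mean were taken without skipping NaNs, the second row would come out as NaN instead of matching the variance, which is the failure the source comment warns about.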