Repository: github.com/dask/dask

Function _cholesky

dask/array/linalg.py, lines 1321–1399

Private function that performs a blocked Cholesky decomposition of a square, 2-D dask array and returns both the lower and the upper triangular factors.
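The relationship between the two returned factors can be checked on a small dense matrix with NumPy alone. The sketch below is an illustration, not dask code. The helper `lower_and_upper` and the test matrix are hypothetical, and the sketch assumes a symmetric (Hermitian) positive-definite input. NumPy returns the lower factor `L`, and the upper factor is taken as its conjugate transpose, so that `A = L @ U`.

```python
import numpy as np


def lower_and_upper(a):
    """Return (L, U) with a == L @ U, L lower triangular and U = L^H (hypothetical helper)."""
    lower = np.linalg.cholesky(a)  # NumPy computes the lower factor
    upper = lower.conj().T         # upper factor is its conjugate transpose
    return lower, upper


# An arbitrary symmetric positive-definite matrix, chosen for illustration.
a = np.array([[4.0, 2.0, 0.4],
              [2.0, 5.0, 1.0],
              [0.4, 1.0, 3.0]])
L, U = lower_and_upper(a)
print(np.allclose(L @ U, a))       # True
print(np.allclose(L, np.tril(L)))  # True: L is lower triangular
print(np.allclose(U, np.triu(U)))  # True: U is upper triangular
```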

Signature: _cholesky(a)
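Before building any graph, the listing below rejects inputs that fail any of three conditions:

- the array is not 2-D;
- the array is not square;
- the array is not split into chunks of one common size.

The pure-Python sketch mirrors those three checks on plain `shape` and `chunks` tuples. The helper names `check_cholesky_input` and `is_valid` are hypothetical. One consequence of the third check: the set of all chunk sizes must contain exactly one value. A matrix whose size is not a multiple of the chunk size ends in a smaller trailing chunk, so it is rejected as well.

```python
def check_cholesky_input(shape, chunks):
    """Mirror the three input checks of _cholesky (hypothetical helper).

    shape  -- tuple of array dimensions, e.g. (6, 6)
    chunks -- dask-style chunks: one tuple of block sizes per axis
    """
    if len(shape) != 2:
        raise ValueError("Dimension must be 2 to perform cholesky decomposition")
    if shape[0] != shape[1]:
        raise ValueError(
            "Input must be a square matrix to perform cholesky decomposition"
        )
    # All block sizes on both axes must be one and the same value, so every
    # block is square and row blocks line up with column blocks.
    if len(set(chunks[0] + chunks[1])) != 1:
        raise ValueError(
            "All chunks must be a square matrix to perform cholesky decomposition. "
            "Use .rechunk method to change the size of chunks."
        )


def is_valid(shape, chunks):
    """Return True if check_cholesky_input accepts the input."""
    try:
        check_cholesky_input(shape, chunks)
    except ValueError:
        return False
    return True


print(is_valid((6, 6), ((3, 3), (3, 3))))     # True
print(is_valid((6, 6), ((2, 2, 2), (3, 3))))  # False: block sizes differ
print(is_valid((5, 5), ((3, 2), (3, 2))))     # False: trailing 2-wide chunk
print(is_valid((6, 4), ((3, 3), (2, 2))))     # False: not square
```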


```python
def _cholesky(a):
    """
    Private function to perform Cholesky decomposition, which returns both
    lower and upper triangulars.
    """

    if a.ndim != 2:
        raise ValueError("Dimension must be 2 to perform cholesky decomposition")

    xdim, ydim = a.shape
    if xdim != ydim:
        raise ValueError(
            "Input must be a square matrix to perform cholesky decomposition"
        )
    if len(set(a.chunks[0] + a.chunks[1])) != 1:
        msg = (
            "All chunks must be a square matrix to perform cholesky decomposition. "
            "Use .rechunk method to change the size of chunks."
        )
        raise ValueError(msg)

    vdim = len(a.chunks[0])
    hdim = len(a.chunks[1])

    token = tokenize(a)
    name = f"cholesky-{token}"

    # (name_lt_dot, i, j, k, l) corresponds to l_ij.dot(l_kl.T)
    name_lt_dot = f"cholesky-lt-dot-{token}"
    # because transposed results are needed for calculation,
    # we can build graph for upper triangular simultaneously
    name_upper = f"cholesky-upper-{token}"

    # calculates lower triangulars because subscriptions get simpler
    dsk = {}
    for i in range(vdim):
        for j in range(hdim):
            if i < j:
                dsk[name, i, j] = (
                    partial(np.zeros_like, shape=(a.chunks[0][i], a.chunks[1][j])),
                    meta_from_array(a),
                )
                dsk[name_upper, j, i] = (name, i, j)
            elif i == j:
                target = (a.name, i, j)
                if i > 0:
                    prevs = []
                    for p in range(i):
                        prev = name_lt_dot, i, p, i, p
                        dsk[prev] = (np.dot, (name, i, p), (name_upper, p, i))
                        prevs.append(prev)
                    target = (operator.sub, target, (sum, prevs))
                dsk[name, i, i] = (_cholesky_lower, target)
                dsk[name_upper, i, i] = (_conj_transpose, (name, i, i))
            else:
                # solving x.dot(L11.T) = (A21 - L20.dot(L10.T)) is equal to
                # L11.dot(x.T) = A21.T - L10.dot(L20.T)
                # L11.dot(x.T) = A12 - L10.dot(L02)
                # (listing ends here; the function continues to line 1399)
```
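The block loop above follows the standard blocked Cholesky recurrence. Write A_ij, L_ij and U_ij = L_ji^H for the blocks of the input, the lower factor and the upper factor.

- Each diagonal block is L_ii = chol(A_ii - sum over p < i of L_ip U_pi).
- For i > j, as the comment in the `else` branch states, the upper block U_ji solves L_jj U_ji = A_ji - sum over p < j of L_jp U_pi. Then L_ij = U_ji^H.

The NumPy sketch below runs this recurrence eagerly on an in-memory matrix. It is an illustration, not the dask implementation. It assumes a Hermitian positive-definite input whose size is a multiple of the block size `b`, and the function name `blocked_cholesky` is hypothetical. NumPy has no dedicated triangular solver, so the general `np.linalg.solve` stands in for one here.

```python
import numpy as np


def blocked_cholesky(a, b):
    """Eager blocked Cholesky following the recurrence in the listing (hypothetical).

    Assumes `a` is Hermitian positive definite and its size is a multiple of
    the block size `b`. Returns (lower, upper) with lower @ upper == a.
    """
    a = np.asarray(a)
    a = a.astype(np.result_type(a.dtype, np.float64))  # ints -> float, keep complex
    n = a.shape[0]
    if a.ndim != 2 or a.shape[1] != n or n % b != 0:
        raise ValueError("need a square matrix whose size is a multiple of b")
    nb = n // b

    def blk(i, j):
        # Block (i, j) of the input, analogous to the key (a.name, i, j).
        return a[i * b:(i + 1) * b, j * b:(j + 1) * b]

    L = {}  # (i, j) -> lower-factor block, for i >= j
    U = {}  # (j, i) -> upper-factor block, always L[i, j].conj().T
    for i in range(nb):
        for j in range(i + 1):
            if i == j:
                # L_ii = chol(A_ii - sum_{p<i} L_ip @ U_pi)
                target = blk(i, i).copy()
                for p in range(i):
                    target -= L[i, p] @ U[p, i]
                L[i, i] = np.linalg.cholesky(target)
                U[i, i] = L[i, i].conj().T
            else:
                # i > j: solve L_jj @ U_ji = A_ji - sum_{p<j} L_jp @ U_pi
                rhs = blk(j, i).copy()
                for p in range(j):
                    rhs -= L[j, p] @ U[p, i]
                U[j, i] = np.linalg.solve(L[j, j], rhs)  # stand-in for a triangular solve
                L[i, j] = U[j, i].conj().T

    # Blocks above the diagonal stay zero, like the np.zeros_like tasks.
    lower = np.zeros_like(a)
    for (i, j), block in L.items():
        lower[i * b:(i + 1) * b, j * b:(j + 1) * b] = block
    return lower, lower.conj().T


rng = np.random.default_rng(0)
m = rng.standard_normal((6, 6))
spd = m @ m.T + 6 * np.eye(6)  # symmetric positive-definite test matrix
lower, upper = blocked_cholesky(spd, 2)
print(np.allclose(lower, np.linalg.cholesky(spd)))  # True
print(np.allclose(lower @ upper, spd))              # True
```

Because the Cholesky factor with a positive diagonal is unique, the blocked result agrees with NumPy's unblocked factor up to rounding.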

Callers (2)

solve (function) · 0.85
cholesky (function) · 0.85

Calls (6)

meta_from_array (function) · 0.90
array_safe (function) · 0.90
Array (class) · 0.90
set (class) · 0.85
from_collections (method) · 0.80
tokenize (function) · 0.50
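The `dsk` dictionary in the listing uses the tuple-based task format from dask's graph specification. Each key, such as `(name, i, j)`, maps to a tuple whose first element is a callable and whose remaining elements are arguments. An argument may itself be another key, a nested task like `(sum, prevs)`, or a list of keys like `prevs`.

The toy evaluator below resolves such a dictionary recursively. It runs on a hand-written graph for a 2×2 matrix split into 1×1 blocks, using the same key layout as the listing.

A few assumptions apply:

- The helpers `is_task`, `_resolve` and `get` are hypothetical and far simpler than dask's schedulers.
- `np.linalg.cholesky`, `np.transpose` and `np.linalg.solve` stand in for the private `_cholesky_lower`, `_conj_transpose` and the solve step, which are not shown here.

```python
import operator
from functools import partial

import numpy as np


def is_task(x):
    # A task is a non-empty tuple whose first element is callable.
    return type(x) is tuple and len(x) > 0 and callable(x[0])


def _resolve(dsk, arg, cache):
    if isinstance(arg, list):  # e.g. the `prevs` list of keys
        return [_resolve(dsk, x, cache) for x in arg]
    if is_task(arg):  # task or nested task such as (sum, prevs)
        func, *args = arg
        return func(*[_resolve(dsk, x, cache) for x in args])
    if isinstance(arg, tuple) and arg in dsk:  # reference to another key
        return get(dsk, arg, cache)
    return arg  # literal value


def get(dsk, key, cache=None):
    """Toy recursive evaluator for a dict of tuple tasks (hypothetical helper)."""
    if cache is None:
        cache = {}
    if key not in cache:
        cache[key] = _resolve(dsk, dsk[key], cache)
    return cache[key]


# "L" plays the role of `name` and "U" the role of `name_upper`.
a = np.array([[4.0, 2.0], [2.0, 5.0]])
dsk = {("a", i, j): a[i:i + 1, j:j + 1] for i in range(2) for j in range(2)}
dsk.update({
    # i == j == 0: nothing to subtract yet
    ("L", 0, 0): (np.linalg.cholesky, ("a", 0, 0)),
    ("U", 0, 0): (np.transpose, ("L", 0, 0)),
    # i < j: zero block above the diagonal, aliased into the upper factor
    ("L", 0, 1): (partial(np.zeros_like, shape=(1, 1)), ("a", 0, 1)),
    ("U", 1, 0): ("L", 0, 1),
    # i > j: U_01 from a solve against L_00
    ("U", 0, 1): (np.linalg.solve, ("L", 0, 0), ("a", 0, 1)),
    ("L", 1, 0): (np.transpose, ("U", 0, 1)),
    # i == j == 1: subtract the summed "lt-dot" products first
    ("lt-dot", 1, 0, 1, 0): (np.dot, ("L", 1, 0), ("U", 0, 1)),
    ("L", 1, 1): (np.linalg.cholesky,
                  (operator.sub, ("a", 1, 1), (sum, [("lt-dot", 1, 0, 1, 0)]))),
    ("U", 1, 1): (np.transpose, ("L", 1, 1)),
})

cache = {}
lower = np.block([[get(dsk, ("L", i, j), cache) for j in range(2)] for i in range(2)])
print(np.allclose(lower, [[2.0, 0.0], [1.0, 2.0]]))  # True
print(np.allclose(lower @ lower.T, a))               # True
```

The shared `cache` means each task runs once even when several keys depend on it. A real scheduler provides the same guarantee while also running independent blocks in parallel.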

Tested by: no test coverage detected
