Function solve_triangular

dask/array/linalg.py:1114–1205 in github.com/dask/dask


Signature: solve_triangular(a, b, lower=False)
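The math behind the function fits in a small dense reference. The sketch below is plain Python with no dask or NumPy. It uses forward substitution when `lower=True` and back substitution otherwise, and it reads only the triangle that `lower` selects, which matches the function's docstring ("Default is to use upper triangle"). The helper `solve_tri_dense` is hypothetical and assumes a non-zero diagonal. dask applies this idea chunk by chunk on dask arrays, not on Python lists.

```python
def solve_tri_dense(a, b, lower=False):
    """Solve a x = b for a square triangular `a` given as a list of rows.

    Hypothetical reference helper, not part of dask. `b` is a flat list
    (the (M,) case). Only the triangle selected by `lower` is read.
    """
    m = len(a)
    x = [0.0] * m
    # Forward substitution walks rows top-down, back substitution bottom-up.
    order = range(m) if lower else range(m - 1, -1, -1)
    for i in order:
        # Columns whose unknowns are already solved: j < i (lower) or j > i (upper).
        known = range(i) if lower else range(i + 1, m)
        s = sum(a[i][j] * x[j] for j in known)
        x[i] = (b[i] - s) / a[i][i]
    return x


print(solve_tri_dense([[2.0, 1.0], [0.0, 4.0]], [5.0, 8.0]))  # [1.5, 2.0]
# With lower=True the 99.0 above the diagonal is never read.
print(solve_tri_dense([[2.0, 99.0], [1.0, 4.0]], [4.0, 9.0], lower=True))  # [2.0, 1.75]
```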


```python
def solve_triangular(a, b, lower=False):
    """
    Solve the equation `a x = b` for `x`, assuming a is a triangular matrix.

    Parameters
    ----------
    a : (M, M) array_like
        A triangular matrix
    b : (M,) or (M, N) array_like
        Right-hand side matrix in `a x = b`
    lower : bool, optional
        Use only data contained in the lower triangle of `a`.
        Default is to use upper triangle.

    Returns
    -------
    x : (M,) or (M, N) array
        Solution to the system `a x = b`. Shape of return matches `b`.
    """

    if a.ndim != 2:
        raise ValueError("a must be 2 dimensional")
    if b.ndim <= 2:
        if a.shape[1] != b.shape[0]:
            raise ValueError("a.shape[1] and b.shape[0] must be equal")
        if a.chunks[1] != b.chunks[0]:
            msg = "a.chunks[1] and b.chunks[0] must be equal. Use .rechunk method to change the size of chunks."
            raise ValueError(msg)
    else:
        raise ValueError("b must be 1 or 2 dimensional")

    vchunks = len(a.chunks[1])
    hchunks = 1 if b.ndim == 1 else len(b.chunks[1])
    token = tokenize(a, b, lower)
    name = f"solve-triangular-{token}"

    # for internal calculation
    # (name, i, j, k, l) corresponds to a_ij.dot(b_kl)
    name_mdot = f"solve-tri-dot-{token}"

    def _b_init(i, j):
        if b.ndim == 1:
            return b.name, i
        else:
            return b.name, i, j

    def _key(i, j):
        if b.ndim == 1:
            return name, i
        else:
            return name, i, j

    dsk = {}
    if lower:
        for i in range(vchunks):
            for j in range(hchunks):
                target = _b_init(i, j)
                if i > 0:
                    ...  # excerpt ends here; the rest of the function is not shown
```
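The excerpt stops right after `if i > 0:`, so the rest of the lower branch is not shown. The `name_mdot` comment says that a key `(name_mdot, i, j, k, l)` stands for `a_ij.dot(b_kl)`, and `_solve_triangular_lower` appears among the called functions. Together these suggest block forward substitution: for chunk row i, subtract the partial products a_ik · x_k of the already-solved chunks k < i from b_i, then solve the diagonal block a_ii. The sketch below models that reading as a toy dict-of-tasks graph with a tiny recursive evaluator, in plain Python. It is a hypothetical sketch, not dask's code or scheduler. Every name in it (`build_lower_graph`, `get`, `matvec`, and the rest) is made up, it handles only a 1-D `b` (so the `l` index is dropped from the `"mdot"` keys), and it assumes non-zero diagonal entries.

```python
# Toy blocked forward substitution as a dask-style task graph.
# Hypothetical sketch: not dask's implementation or scheduler.

def matvec(block, vec):
    # Dense block times a vector chunk.
    return [sum(r * v for r, v in zip(row, vec)) for row in block]

def sub_all(target, *parts):
    # Elementwise target - sum(parts).
    out = list(target)
    for p in parts:
        out = [t - q for t, q in zip(out, p)]
    return out

def solve_lower_block(block, rhs):
    # Forward substitution within one diagonal block.
    x = []
    for i, row in enumerate(block):
        x.append((rhs[i] - sum(row[j] * x[j] for j in range(i))) / row[i])
    return x

def build_lower_graph(a_blocks, b_chunks):
    # a_blocks[i][k] (k <= i): block (i, k) of a lower block-triangular matrix.
    # b_chunks[i]: chunk i of a 1-D right-hand side.
    dsk = {}
    for i, chunk in enumerate(b_chunks):
        dsk[("b", i)] = chunk
        for k in range(i + 1):
            dsk[("a", i, k)] = a_blocks[i][k]
    for i in range(len(b_chunks)):
        target = ("b", i)
        if i > 0:
            prevs = []
            for k in range(i):
                # Partial product a_ik . x_k of an already-solved chunk.
                prev = ("mdot", i, k, k)
                dsk[prev] = (matvec, ("a", i, k), ("x", k))
                prevs.append(prev)
            target = (sub_all, target, *prevs)
        dsk[("x", i)] = (solve_lower_block, ("a", i, i), target)
    return dsk

def _eval(dsk, arg, cache):
    # A tuple whose first item is callable is a task; a tuple found in the
    # graph is a key; lists are evaluated elementwise; anything else is data.
    if isinstance(arg, tuple) and arg and callable(arg[0]):
        func, *args = arg
        return func(*(_eval(dsk, x, cache) for x in args))
    if isinstance(arg, list):
        return [_eval(dsk, x, cache) for x in arg]
    if isinstance(arg, tuple) and arg in dsk:
        if arg not in cache:
            cache[arg] = _eval(dsk, dsk[arg], cache)
        return cache[arg]
    return arg

def get(dsk, keys):
    # Evaluate several keys with a shared cache so each task runs once.
    cache = {}
    return [_eval(dsk, k, cache) for k in keys]

# 4x4 lower-triangular system in 2x2 chunks; the exact solution is [1, 2, 3, 4].
a_blocks = [
    [[[2, 0], [1, 1]]],                    # a_00
    [[[1, 0], [0, 2]], [[4, 0], [1, 2]]],  # a_10, a_11
]
dsk = build_lower_graph(a_blocks, [[2.0, 3.0], [13.0, 15.0]])
print(get(dsk, [("x", 0), ("x", 1)]))  # [[1.0, 2.0], [3.0, 4.0]]
```

Keeping each partial product under its own key means the chunks of `x` become ordinary graph dependencies, so a scheduler can run independent products in parallel while still respecting the top-down order of substitution.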

Callers (2): solve, lstsq
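A general solver can reuse a triangular solve by factoring a = L U and then running two triangular solves: L y = b with `lower=True`, then U x = y with the default upper triangle. A least-squares routine typically uses a QR factorization and a single upper-triangular solve R x = Qᵀ b instead. The sketch below shows the LU pattern in plain Python with a Doolittle factorization without pivoting. It illustrates the technique and is not a claim about how dask's `solve` or `lstsq` are written. The helpers `tri_solve`, `lu_nopivot` and `lu_solve` are hypothetical, and the sketch assumes no zero pivots occur.

```python
def tri_solve(t, b, lower):
    # Dense triangular solve: forward if lower, backward otherwise.
    m = len(t)
    x = [0.0] * m
    rows = range(m) if lower else range(m - 1, -1, -1)
    for i in rows:
        cols = range(i) if lower else range(i + 1, m)
        x[i] = (b[i] - sum(t[i][j] * x[j] for j in cols)) / t[i][i]
    return x

def lu_nopivot(a):
    # Doolittle factorization a = L U, unit diagonal on L, no pivoting.
    m = len(a)
    L = [[float(i == j) for j in range(m)] for i in range(m)]
    U = [[0.0] * m for _ in range(m)]
    for i in range(m):
        for j in range(i, m):
            U[i][j] = a[i][j] - sum(L[i][k] * U[k][j] for k in range(i))
        for j in range(i + 1, m):
            L[j][i] = (a[j][i] - sum(L[j][k] * U[k][i] for k in range(i))) / U[i][i]
    return L, U

def lu_solve(a, b):
    L, U = lu_nopivot(a)
    y = tri_solve(L, b, lower=True)     # L y = b
    return tri_solve(U, y, lower=False)  # U x = y

print(lu_solve([[4.0, 3.0], [6.0, 3.0]], [10.0, 12.0]))  # [1.0, 2.0]
```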

Calls (8): meta_from_array, array_safe, Array (class), _b_init, _key, _solve_triangular_lower, from_collections (method), tokenize
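The function names its output layer `solve-triangular-{token}` and its intermediate products `solve-tri-dot-{token}`, with `token = tokenize(a, b, lower)`. Identical inputs therefore produce identical graph keys, and a different `lower` flag produces different ones. The sketch below illustrates only that determinism. `toy_tokenize` is a hypothetical stand-in that hashes the `repr` of its arguments with SHA-256. dask's own `tokenize` handles arbitrary objects, including dask collections, far more carefully. Passing array names instead of the arrays themselves is also an assumption made to keep the example self-contained.

```python
import hashlib

def toy_tokenize(*args):
    # Hypothetical stand-in for dask's tokenize: hash a text form of the
    # arguments so equal inputs always map to the same hex token.
    return hashlib.sha256(repr(args).encode("utf-8")).hexdigest()[:32]

def output_names(a_name, b_name, lower):
    token = toy_tokenize(a_name, b_name, lower)
    # Same prefixes as the listing: result chunks and intermediate dot products.
    return f"solve-triangular-{token}", f"solve-tri-dot-{token}"

same_1 = output_names("array-abc", "array-def", True)
same_2 = output_names("array-abc", "array-def", True)
other = output_names("array-abc", "array-def", False)
print(same_1 == same_2, same_1 == other)  # True False
```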

Tested by: no test coverage detected
