
Function safe_sparse_dot

sklearn/utils/extmath.py:166–238 (github.com/scikit-learn/scikit-learn)

Dot product that handles the sparse matrix case correctly.
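For context on what "the sparse case" involves: SciPy's sparse documentation cautions that `np.dot` is not aware of sparse arrays, whereas the `@` operator dispatches to SciPy's own sparse matrix product, which is what the listing uses. A minimal sketch of that dispatch, assuming SciPy 1.8 or later (for `csr_array`); it shows the operator behavior only, not scikit-learn's code.

```python
import numpy as np
from scipy import sparse

# Same 3 x 2 operand as the docstring example, stored as a CSR sparse array.
X = sparse.csr_array([[1, 2], [3, 4], [5, 6]])
# A dense 2 x 2 identity, so X @ W should reproduce X's values.
W = np.eye(2)

# `@` dispatches to SciPy's sparse matrix product.
sparse_times_dense = X @ W      # sparse @ dense: a dense np.ndarray
sparse_times_sparse = X @ X.T   # sparse @ sparse: stays sparse
```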

(a, b, *, dense_output=False)
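The keyword-only `dense_output` flag in this signature decides whether a sparse-by-sparse product is densified. A minimal 2-D sketch of the documented contract: `sparse_dot_2d` is a hypothetical name, it handles only 2-D operands, and it is not scikit-learn's implementation, which also covers N-D inputs and has a separate branch for CSR/CSC float inputs when `dense_output=True`.

```python
import numpy as np
from scipy import sparse


def sparse_dot_2d(a, b, *, dense_output=False):
    """Hypothetical 2-D-only sketch of the documented dense_output contract."""
    ret = a @ b  # sparse-aware product; sparse @ dense is already dense
    if dense_output and sparse.issparse(ret):
        # Only a sparse @ sparse product is sparse, so densify it on request.
        return ret.toarray()
    return ret


A = sparse.csr_array([[1.0, 0.0], [0.0, 2.0]])
B = sparse.csc_array([[0.0, 3.0], [4.0, 0.0]])

kept_sparse = sparse_dot_2d(A, B)                   # sparse result
densified = sparse_dot_2d(A, B, dense_output=True)  # dense np.ndarray
```

Making `dense_output` keyword-only, as the real signature does, means a call such as `sparse_dot_2d(A, B, True)` raises `TypeError` instead of silently passing a positional flag.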

def safe_sparse_dot(a, b, *, dense_output=False):
    """Dot product that handles the sparse matrix case correctly.

    Parameters
    ----------
    a : {ndarray, sparse matrix}
        First operand of the dot product.
    b : {ndarray, sparse matrix}
        Second operand of the dot product.
    dense_output : bool, default=False
        When False, ``a`` and ``b`` both being sparse will yield sparse output.
        When True, output will always be a dense array.

    Returns
    -------
    dot_product : {ndarray, sparse matrix}
        Sparse if ``a`` and ``b`` are sparse and ``dense_output=False``.

    Examples
    --------
    >>> from scipy.sparse import csr_array
    >>> from sklearn.utils.extmath import safe_sparse_dot
    >>> X = csr_array([[1, 2], [3, 4], [5, 6]])
    >>> dot_product = safe_sparse_dot(X, X.T)
    >>> dot_product.toarray()
    array([[ 5, 11, 17],
           [11, 25, 39],
           [17, 39, 61]])
    """
    xp, _ = get_namespace(a, b)
    if a.ndim > 2 or b.ndim > 2:
        if sparse.issparse(a):
            # sparse is always 2D. Implies b is 3D+
            # [i, j] @ [k, ..., l, m, n] -> [i, k, ..., l, n]
            b_ = np.rollaxis(b, -2)
            b_2d = b_.reshape((b.shape[-2], -1))
            ret = a @ b_2d
            ret = ret.reshape(a.shape[0], *b_.shape[1:])
        elif sparse.issparse(b):
            # sparse is always 2D. Implies a is 3D+
            # [k, ..., l, m] @ [i, j] -> [k, ..., l, j]
            a_2d = a.reshape(-1, a.shape[-1])
            ret = a_2d @ b
            ret = ret.reshape(*a.shape[:-1], b.shape[1])
        else:
            # Alternative for `np.dot` when dealing with a or b having
            # more than 2 dimensions, that works with the array api.
            # If b is 1-dim then the last axis for b is taken otherwise
            # if b is >= 2-dim then the second to last axis is taken.
            b_axis = -1 if b.ndim == 1 else -2
            ret = xp.tensordot(a, b, axes=[-1, b_axis])
    elif (
        dense_output
        and a.ndim == 2
        and b.ndim == 2
        and (sparse.issparse(a) and a.format in ("csc", "csr"))
        and (sparse.issparse(b) and b.format in ("csc", "csr"))
        and a.dtype in (np.float32, np.float64)
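The reshape bookkeeping in the three N-D branches of the listing can be checked on small random inputs. This sketch uses only NumPy and SciPy (it does not import scikit-learn), repeats each branch's steps outside the function, and compares the results with `np.tensordot` or `np.dot` on dense copies; the shapes are arbitrary choices and `dense_nd_dot` is a hypothetical helper name.

```python
import numpy as np
from scipy import sparse

rng = np.random.default_rng(0)

# Branch 1: sparse 2-D `a` @ dense N-D `b`, [i, j] @ [k, m, n] -> [i, k, n] with j == m.
a_dense = rng.standard_normal((4, 5))
a_dense[a_dense < 0] = 0.0                           # zero out some entries
a_sp = sparse.csr_matrix(a_dense)
b_3d = rng.standard_normal((3, 5, 2))

b_ = np.rollaxis(b_3d, -2)                           # contracted axis moved to front: (5, 3, 2)
b_2d = b_.reshape((b_3d.shape[-2], -1))              # flatten the remaining axes: (5, 6)
ret1 = a_sp @ b_2d                                   # sparse-aware 2-D product: (4, 6)
ret1 = ret1.reshape(a_sp.shape[0], *b_.shape[1:])    # back to (4, 3, 2)

# Reference: contract a's last axis with b's second-to-last axis on dense data.
expected1 = np.tensordot(a_dense, b_3d, axes=([-1], [-2]))

# Branch 2: dense N-D `a` @ sparse 2-D `b`, [k, l, m] @ [i, j] -> [k, l, j] with m == i.
a_3d = rng.standard_normal((2, 3, 5))
b_dense = rng.standard_normal((5, 4))
b_dense[b_dense < 0] = 0.0
b_sp = sparse.csr_matrix(b_dense)

a_2d = a_3d.reshape(-1, a_3d.shape[-1])              # stack leading axes into rows: (6, 5)
ret2 = a_2d @ b_sp                                   # dense @ sparse: (6, 4)
ret2 = ret2.reshape(*a_3d.shape[:-1], b_sp.shape[1]) # back to (2, 3, 4)
expected2 = np.tensordot(a_3d, b_dense, axes=([-1], [0]))


# Branch 3: both operands dense, at least one of them N-D.
def dense_nd_dot(a, b):
    # Contract a's last axis with b's last axis (1-D b) or second-to-last axis (N-D b).
    b_axis = -1 if b.ndim == 1 else -2
    return np.tensordot(a, b, axes=[-1, b_axis])


a_nd = rng.standard_normal((2, 3, 4))
b_nd = rng.standard_normal((5, 4, 6))
b_1d = rng.standard_normal(4)
ret3 = dense_nd_dot(a_nd, b_nd)                      # (2, 3, 5, 6)
ret3_vec = dense_nd_dot(a_nd, b_1d)                  # (2, 3)
```

Branch 3 agrees with `np.dot` because `np.dot` on N-D inputs also contracts the last axis of `a` with the second-to-last axis of `b` (or the only axis of a 1-D `b`); per the listing's own comment, `tensordot` is used there because it works with the array API.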

Callers (15)

transform (method) · 0.90
transform (method) · 0.90
transform (method) · 0.90
_count (method) · 0.90
_joint_log_likelihood (method) · 0.90
_count (method) · 0.90
_joint_log_likelihood (method) · 0.90
_count (method) · 0.90
_joint_log_likelihood (method) · 0.90
test_safe_sparse_dot_2d (function) · 0.90
test_safe_sparse_dot_nd (function) · 0.90

Calls (2)

get_namespace (function) · 0.90
sparse_matmul_to_dense (function) · 0.90
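Which 2-D inputs reach the dense-output `elif` branch of the listing can be checked with a small predicate. This is a hypothetical helper (`meets_visible_dense_kernel_conditions` is not a scikit-learn name) that copies only the clauses visible in the source listing, which is cut off before the condition closes, so the real test may include further clauses. The Calls list names `sparse_matmul_to_dense`, but the listing does not show where it is called.

```python
import numpy as np
from scipy import sparse


def meets_visible_dense_kernel_conditions(a, b, *, dense_output=False):
    """Hypothetical: mirrors only the elif clauses shown in the listing."""
    return bool(
        dense_output
        and a.ndim == 2
        and b.ndim == 2
        # Both operands must be CSR or CSC sparse matrices.
        and (sparse.issparse(a) and a.format in ("csc", "csr"))
        and (sparse.issparse(b) and b.format in ("csc", "csr"))
        # The first operand must hold single or double precision floats.
        and a.dtype in (np.float32, np.float64)
    )


A = sparse.csr_matrix([[1.0, 0.0, 2.0], [0.0, 3.0, 0.0]])
B = sparse.csc_matrix([[1.0, 0.0], [0.0, 1.0], [2.0, 0.0]])

qualifies = meets_visible_dense_kernel_conditions(A, B, dense_output=True)
```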

Tested by (6)

test_safe_sparse_dot_2d (function) · 0.72
test_safe_sparse_dot_nd (function) · 0.72
kfunc (function) · 0.72
