Dot product that handles the sparse matrix case correctly.
safe_sparse_dot(a, b, *, dense_output=False)
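As a rough illustration of the `dense_output` flag, the sketch below calls the function on two sparse operands with the flag off and on, and then on a sparse/dense pair. It assumes `scikit-learn`, `scipy` and `numpy` are installed; the variable names are made up for this example.

```python
import numpy as np
from scipy.sparse import csr_array, issparse
from sklearn.utils.extmath import safe_sparse_dot

X = csr_array([[1, 2], [3, 4], [5, 6]])

# Both operands sparse, dense_output left at its default (False):
# the product is returned in a sparse container.
sparse_result = safe_sparse_dot(X, X.T)

# Same operands with dense_output=True: the product comes back dense.
dense_result = safe_sparse_dot(X, X.T, dense_output=True)

# One sparse and one dense operand: the product is dense regardless of the flag.
mixed_result = safe_sparse_dot(X, np.ones((2, 1)))

print(issparse(sparse_result), type(dense_result).__name__, mixed_result.shape)
```

The flag only matters when both operands are sparse; it saves a separate `.toarray()` call when the caller needs a dense result anyway.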
def safe_sparse_dot(a, b, *, dense_output=False):
    """Dot product that handles the sparse matrix case correctly.

    Parameters
    ----------
    a : {ndarray, sparse matrix}
        First operand of the dot product.
    b : {ndarray, sparse matrix}
        Second operand of the dot product.
    dense_output : bool, default=False
        When False, ``a`` and ``b`` both being sparse will yield sparse output.
        When True, output will always be a dense array.

    Returns
    -------
    dot_product : {ndarray, sparse matrix}
        Sparse if ``a`` and ``b`` are sparse and ``dense_output=False``.

    Examples
    --------
    >>> from scipy.sparse import csr_array
    >>> from sklearn.utils.extmath import safe_sparse_dot
    >>> X = csr_array([[1, 2], [3, 4], [5, 6]])
    >>> dot_product = safe_sparse_dot(X, X.T)
    >>> dot_product.toarray()
    array([[ 5, 11, 17],
           [11, 25, 39],
           [17, 39, 61]])
    """
    xp, _ = get_namespace(a, b)
    if a.ndim > 2 or b.ndim > 2:
        if sparse.issparse(a):
            # sparse is always 2D. Implies b is 3D+
            # [i, j] @ [k, ..., l, m, n] -> [i, k, ..., l, n]
            b_ = np.rollaxis(b, -2)
            b_2d = b_.reshape((b.shape[-2], -1))
            ret = a @ b_2d
            ret = ret.reshape(a.shape[0], *b_.shape[1:])
        elif sparse.issparse(b):
            # sparse is always 2D. Implies a is 3D+
            # [k, ..., l, m] @ [i, j] -> [k, ..., l, j]
            a_2d = a.reshape(-1, a.shape[-1])
            ret = a_2d @ b
            ret = ret.reshape(*a.shape[:-1], b.shape[1])
        else:
            # Alternative for `np.dot` when dealing with a or b having
            # more than 2 dimensions, that works with the array api.
            # If b is 1-dim then the last axis for b is taken otherwise
            # if b is >= 2-dim then the second to last axis is taken.
            b_axis = -1 if b.ndim == 1 else -2
            ret = xp.tensordot(a, b, axes=[-1, b_axis])
    elif (
        dense_output
        and a.ndim == 2
        and b.ndim == 2
        and (sparse.issparse(a) and a.format in ("csc", "csr"))
        and (sparse.issparse(b) and b.format in ("csc", "csr"))
        and a.dtype in (np.float32, np.float64)
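The reshaping used in the three N-D branches can be checked against `np.dot` with a small standalone sketch. It uses dense NumPy arrays as stand-ins for the sparse operands so that only `numpy` is required, and the helper names (`matmul_2d_by_nd`, `matmul_nd_by_2d`, `dot_via_tensordot`) are hypothetical, introduced only for this illustration.

```python
import numpy as np


def matmul_2d_by_nd(a, b):
    """Mirror of the `sparse.issparse(a)` branch: 2-D `a` times N-D `b`."""
    # Move the contracted axis of b (second to last) to the front.
    b_ = np.rollaxis(b, -2)
    # Flatten the remaining axes so an ordinary 2-D product applies.
    b_2d = b_.reshape((b.shape[-2], -1))
    ret = a @ b_2d
    # Unflatten: rows of a first, then the remaining axes of b.
    return ret.reshape(a.shape[0], *b_.shape[1:])


def matmul_nd_by_2d(a, b):
    """Mirror of the `sparse.issparse(b)` branch: N-D `a` times 2-D `b`."""
    # Collapse all leading axes of a into a single row axis.
    a_2d = a.reshape(-1, a.shape[-1])
    ret = a_2d @ b
    # Restore the leading axes of a, followed by the columns of b.
    return ret.reshape(*a.shape[:-1], b.shape[1])


def dot_via_tensordot(a, b):
    """Mirror of the dense branch: contract a's last axis with b's matching axis."""
    b_axis = -1 if b.ndim == 1 else -2
    return np.tensordot(a, b, axes=[-1, b_axis])


rng = np.random.default_rng(0)
a2 = rng.standard_normal((3, 4))
b4 = rng.standard_normal((2, 5, 4, 6))
a3 = rng.standard_normal((2, 5, 3))
b2 = rng.standard_normal((3, 4))
b1 = rng.standard_normal(3)

out_2d_nd = matmul_2d_by_nd(a2, b4)      # shape (3, 2, 5, 6)
out_nd_2d = matmul_nd_by_2d(a3, b2)      # shape (2, 5, 4)
out_nd_nd = dot_via_tensordot(a3, b4[..., :3, :])  # shape (2, 5, 2, 5, 6)
out_nd_1d = dot_via_tensordot(a3, b1)    # shape (2, 5)
```

In each case the result has the same shape and values as `np.dot` on the same operands, which is the behaviour the branch comments describe.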