(lhs, rhs, axes=2)
| 293 | |
# Dask counterpart of ``np.tensordot``: sum the products of ``lhs`` and
# ``rhs`` over pairs of axes.
#
# NOTE(review): notes are kept as comments rather than a docstring because
# this function's docstring presumably comes from ``@derived_from(np)``;
# confirm how that decorator treats an existing docstring before adding one.
#
# Parameters
#   lhs, rhs : dask ``Array`` objects; anything else is first wrapped with
#              ``from_array``.
#   axes     : either an int N (contract the last N axes of ``lhs`` with the
#              first N axes of ``rhs``) or a pair ``(left_axes, right_axes)``
#              whose items are an int or a sequence of ints. Negative axis
#              numbers are accepted.
# Returns
#   The array built by ``blockwise``. If either input is sparse (as reported
#   by ``_tensordot_is_sparse``) and exactly one axis pair is contracted, the
#   blocks are concatenated along that axis and contracted directly.
#   Otherwise each block pair is contracted separately and the partial
#   results are summed over the contracted ``lhs`` axes.
# Raises
#   ValueError if ``left_axes`` and ``right_axes`` have different lengths.
@derived_from(np)
def tensordot(lhs, rhs, axes=2):
    if not isinstance(lhs, Array):
        lhs = from_array(lhs)
    if not isinstance(rhs, Array):
        rhs = from_array(rhs)

    if isinstance(axes, Iterable):
        left_axes, right_axes = axes
    else:
        # Integer N: the last N axes of lhs pair with the first N axes of rhs.
        left_axes = tuple(range(lhs.ndim - axes, lhs.ndim))
        right_axes = tuple(range(0, axes))
    # Accept a bare int or a list on either side; other sequences pass through.
    if isinstance(left_axes, Integral):
        left_axes = (left_axes,)
    if isinstance(right_axes, Integral):
        right_axes = (right_axes,)
    if isinstance(left_axes, list):
        left_axes = tuple(left_axes)
    if isinstance(right_axes, list):
        right_axes = tuple(right_axes)
    # zip() below would silently drop unpaired axes and build a graph over a
    # partial pairing; reject the mismatch up front, as NumPy does.
    if len(left_axes) != len(right_axes):
        raise ValueError(
            "tensordot: left and right axes must have the same length, "
            f"got {len(left_axes)} and {len(right_axes)}"
        )

    is_sparse = _tensordot_is_sparse(lhs) or _tensordot_is_sparse(rhs)
    # Single-axis sparse contraction: concatenate blocks along the contracted
    # axis and contract once, instead of summing per-block partial results.
    concatenate = bool(is_sparse) and len(left_axes) == 1
    dt = np.promote_types(lhs.dtype, rhs.dtype)

    left_index = list(range(lhs.ndim))
    right_index = list(range(lhs.ndim, lhs.ndim + rhs.ndim))
    out_index = left_index + right_index
    adjust_chunks = {}
    for left_ax, right_ax in zip(left_axes, right_axes):
        # Give the contracted rhs dimension the same index label as its lhs
        # partner so blockwise aligns them; the rhs label leaves the output.
        out_index.remove(right_index[right_ax])
        right_index[right_ax] = left_index[left_ax]
        if concatenate:
            out_index.remove(left_index[left_ax])
        else:
            # The contracted dimension stays in the intermediate with
            # size-1 chunks; the final .sum() below reduces it away.
            adjust_chunks[left_index[left_ax]] = lambda c: 1
    intermediate = blockwise(
        _tensordot,
        out_index,
        lhs,
        left_index,
        rhs,
        right_index,
        dtype=dt,
        concatenate=concatenate,
        adjust_chunks=adjust_chunks,
        axes=(left_axes, right_axes),
        is_sparse=is_sparse,
    )
    if concatenate:
        return intermediate
    # The intermediate has more dimensions than lhs, so negative lhs axes
    # must be made absolute before being reused as reduction axes.
    sum_axes = [ax if ax >= 0 else lhs.ndim + ax for ax in left_axes]
    return intermediate.sum(axis=sum_axes)
| 349 | |
| 350 | |
| 351 | @derived_from(np, ua_args=["out"]) |
no test coverage detected
searching dependent graphs…