hub / github.com/dask/dask / tensordot

Function tensordot

dask/array/routines.py:295–348
(lhs, rhs, axes=2)

Source from the content-addressed store, hash-verified

@derived_from(np)
def tensordot(lhs, rhs, axes=2):
    if not isinstance(lhs, Array):
        lhs = from_array(lhs)
    if not isinstance(rhs, Array):
        rhs = from_array(rhs)

    if isinstance(axes, Iterable):
        left_axes, right_axes = axes
    else:
        left_axes = tuple(range(lhs.ndim - axes, lhs.ndim))
        right_axes = tuple(range(0, axes))
    if isinstance(left_axes, Integral):
        left_axes = (left_axes,)
    if isinstance(right_axes, Integral):
        right_axes = (right_axes,)
    if isinstance(left_axes, list):
        left_axes = tuple(left_axes)
    if isinstance(right_axes, list):
        right_axes = tuple(right_axes)
    is_sparse = _tensordot_is_sparse(lhs) or _tensordot_is_sparse(rhs)
    if is_sparse and len(left_axes) == 1:
        concatenate = True
    else:
        concatenate = False
    dt = np.promote_types(lhs.dtype, rhs.dtype)
    left_index = list(range(lhs.ndim))
    right_index = list(range(lhs.ndim, lhs.ndim + rhs.ndim))
    out_index = left_index + right_index
    adjust_chunks = {}
    for l, r in zip(left_axes, right_axes):
        out_index.remove(right_index[r])
        right_index[r] = left_index[l]
        if concatenate:
            out_index.remove(left_index[l])
        else:
            adjust_chunks[left_index[l]] = lambda c: 1
    intermediate = blockwise(
        _tensordot,
        out_index,
        lhs,
        left_index,
        rhs,
        right_index,
        dtype=dt,
        concatenate=concatenate,
        adjust_chunks=adjust_chunks,
        axes=(left_axes, right_axes),
        is_sparse=is_sparse,
    )
    if concatenate:
        return intermediate
    else:
        left_axes = [ax if ax >= 0 else lhs.ndim + ax for ax in left_axes]
        return intermediate.sum(axis=left_axes)
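
The index bookkeeping in the listing above can be hard to follow by eye, so here is a minimal standalone sketch of it. It uses only the standard library and works on dimension counts rather than arrays. The helper names `normalize_axes` and `plan_indices` and the returned `collapsed` list are illustrative assumptions, not part of dask; `collapsed` stands in for the keys that the listing stores in `adjust_chunks`.

```python
from collections.abc import Iterable
from numbers import Integral


def normalize_axes(lhs_ndim, axes=2):
    """Return (left_axes, right_axes) as tuples, following the listing's rules."""
    if isinstance(axes, Iterable):
        left_axes, right_axes = axes
    else:
        # An integer N contracts the last N axes of lhs with the first N of rhs.
        left_axes = tuple(range(lhs_ndim - axes, lhs_ndim))
        right_axes = tuple(range(0, axes))
    if isinstance(left_axes, Integral):
        left_axes = (left_axes,)
    if isinstance(right_axes, Integral):
        right_axes = (right_axes,)
    return tuple(left_axes), tuple(right_axes)


def plan_indices(lhs_ndim, rhs_ndim, axes=2, concatenate=False):
    """Hypothetical helper: compute the index labels that would go to blockwise."""
    left_axes, right_axes = normalize_axes(lhs_ndim, axes)
    left_index = list(range(lhs_ndim))
    right_index = list(range(lhs_ndim, lhs_ndim + rhs_ndim))
    out_index = left_index + right_index
    collapsed = []  # labels whose chunks shrink to size 1 (adjust_chunks keys)
    for l, r in zip(left_axes, right_axes):
        # The contracted rhs axis takes the label of its lhs partner.
        out_index.remove(right_index[r])
        right_index[r] = left_index[l]
        if concatenate:
            out_index.remove(left_index[l])
        else:
            collapsed.append(left_index[l])
    return left_index, right_index, out_index, collapsed


# Matrix product: contract axis 1 of lhs with axis 0 of rhs.
print(plan_indices(2, 2, axes=1))
# Same contraction on the concatenate path: the shared label leaves the output.
print(plan_indices(2, 2, axes=1, concatenate=True))
```

On the non-concatenate path the shared label stays in `out_index`, which is why the listing finishes with `intermediate.sum(axis=left_axes)`; on the concatenate path the label is already gone and the intermediate is returned directly.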

Callers 3

dot · Method · 0.90
_tensordot · Function · 0.85
dot · Function · 0.85

Calls 5

from_array · Function · 0.90
_tensordot_is_sparse · Function · 0.85
remove · Method · 0.80
blockwise · Function · 0.70
sum · Method · 0.45
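
The calls to `blockwise` and `sum` listed above reflect a two-stage strategy on the non-concatenate path: contract each pair of chunks independently, then add the partial results along the contracted axis. The sketch below illustrates that idea for a plain matrix product using nested Python lists. It is not dask code: `matmul` and `chunked_matmul` are hypothetical helpers, and the `chunk` parameter is an assumed stand-in for a chunk size along the contracted axis.

```python
def matmul(a, b):
    """Dense matrix product of two nested-list matrices."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def chunked_matmul(a, b, chunk):
    """Split the contracted axis into chunks, multiply per chunk, then sum."""
    partials = []
    for start in range(0, len(b), chunk):
        a_blk = [row[start:start + chunk] for row in a]  # columns of a
        b_blk = b[start:start + chunk]                   # matching rows of b
        partials.append(matmul(a_blk, b_blk))
    n_rows, n_cols = len(a), len(b[0])
    # Adding the partial products plays the role of the final sum over the
    # contracted axis in the listing.
    return [
        [sum(p[i][j] for p in partials) for j in range(n_cols)]
        for i in range(n_rows)
    ]


a = [[1, 2, 3], [4, 5, 6]]
b = [[7, 8], [9, 10], [11, 12]]
print(chunked_matmul(a, b, chunk=2) == matmul(a, b))
```

Each partial product depends only on one slice of the contracted axis, so the per-chunk work can run independently before the reduction.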

Tested by

no test coverage detected
