github.com/ray-project/ray

Class TensorArray

python/ray/data/_internal/tensor_extensions/pandas.py:647–1465

Pandas `ExtensionArray` representing a tensor column, i.e. a column consisting of ndarrays as elements. This extension supports tensors in which the elements have different shapes. However, each tensor element must be non-ragged, i.e. each tensor element must have a well-defined, non-ragged shape.
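
The shape rule can be illustrated without Ray. The sketch below uses a hypothetical helper, `is_well_shaped`, that is not part of Ray and is not Ray's own validation logic. It separates elements that qualify (each converts to a regular NumPy array, even though shapes differ between elements) from a ragged element such as `[[1, 2], [3]]`, which has no single shape. It assumes only NumPy.

```python
import warnings

import numpy as np


def is_well_shaped(element) -> bool:
    """Hypothetical helper: True if `element` converts to a regular ndarray.

    A ragged nested list (rows of different lengths) either raises
    ValueError (NumPy >= 1.24) or falls back to an object array (older NumPy).
    """
    try:
        with warnings.catch_warnings():
            # Older NumPy emits a deprecation warning for ragged input; silence it.
            warnings.simplefilter("ignore")
            arr = np.asarray(element)
    except ValueError:
        return False
    # An object dtype means NumPy could not find a single regular shape.
    return arr.dtype != object


# Elements of a variable-shaped tensor column: shapes differ per element,
# but each element on its own has a well-defined shape.
variable_shaped = [np.zeros((2, 2)), np.ones((3, 1)), np.arange(4)]
ragged = [[1, 2], [3]]  # row lengths 2 and 1: no single shape

print([is_well_shaped(e) for e in variable_shaped])  # [True, True, True]
print(is_well_shaped(ragged))  # False
print([np.asarray(e).shape for e in variable_shaped])  # [(2, 2), (3, 1), (4,)]
```

The check deliberately looks only at each element in isolation: differing shapes across elements are allowed, while an element with no single shape is not.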


```python
@PublicAPI(stability="beta")
class TensorArray(
    pd.api.extensions.ExtensionArray,
    _TensorOpsMixin,
    _TensorScalarCastMixin,
):
    """
    Pandas `ExtensionArray` representing a tensor column, i.e. a column
    consisting of ndarrays as elements.

    This extension supports tensors in which the elements have different shapes.
    However, each tensor element must be non-ragged, i.e. each tensor element must have
    a well-defined, non-ragged shape.

    Examples:
        >>> # Create a DataFrame with a tensor column of three (2, 2, 2) ndarrays.
        >>> import pandas as pd
        >>> import numpy as np
        >>> import ray
        >>> from ray.data.extensions import TensorArray
        >>> df = pd.DataFrame({
        ...     "one": [1, 2, 3],
        ...     "two": TensorArray(np.arange(24).reshape((3, 2, 2, 2)))})
        >>> # Note that the column dtype is TensorDtype.
        >>> df.dtypes # doctest: +SKIP
        one          int64
        two    TensorDtype(shape=(2, 2, 2), dtype=int64)
        dtype: object
        >>> # Pandas is aware of this tensor column, and we can do the
        >>> # typical DataFrame operations on this column.
        >>> col = 2 * (df["two"] + 1)
        >>> # The ndarrays underlying the tensor column will be manipulated,
        >>> # but the column itself will continue to be a Pandas type.
        >>> type(col) # doctest: +SKIP
        pandas.core.series.Series
        >>> col # doctest: +SKIP
        0   [[[ 2  4]
              [ 6  8]]
             [[10 12]
              [14 16]]]
        1   [[[18 20]
              [22 24]]
             [[26 28]
              [30 32]]]
        2   [[[34 36]
              [38 40]]
             [[42 44]
              [46 48]]]
        Name: two, dtype: TensorDtype(shape=(2, 2, 2), dtype=int64)
        >>> # Once you do an aggregation on that column that returns a single
        >>> # tensor, you get back our TensorArrayElement type.
        >>> tensor = col.mean() # doctest: +SKIP
        >>> type(tensor) # doctest: +SKIP
        ray.data.extensions.tensor_extension.TensorArrayElement
        >>> tensor # doctest: +SKIP
        array([[[18., 20.],
                [22., 24.]],
               [[26., 28.],
                [30., 32.]]])
```
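
The printed values in the example can be cross-checked with plain NumPy, assuming (as the docstring describes) that arithmetic on the tensor column is applied elementwise to the underlying ndarray and that `mean()` reduces over the row axis. Each of the 3 rows holds one (2, 2, 2) element, so the per-element shape is `x.shape[1:]`. This reproduces the numbers only; it is not Ray's implementation.

```python
import numpy as np

# The array passed to TensorArray in the example: axis 0 indexes rows.
x = np.arange(24).reshape((3, 2, 2, 2))
element_shape = x.shape[1:]  # shape of each row's tensor

# Elementwise arithmetic over the whole column; 2 * (x + 1) yields the
# printed column values 2, 4, ..., 48.
col = 2 * (x + 1)

# Aggregating over the row axis collapses the column to a single
# (2, 2, 2) tensor, matching the printed mean.
mean = col.mean(axis=0)

print(element_shape)     # (2, 2, 2)
print(col[0].tolist())   # [[[2, 4], [6, 8]], [[10, 12], [14, 16]]]
print(mean.tolist())     # [[[18.0, 20.0], [22.0, 24.0]], [[26.0, 28.0], [30.0, 32.0]]]
```

Because the three rows form an arithmetic progression, the row-wise mean equals the middle row, which is why the printed mean matches row 1 of the column.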
