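Function test_tensors_basic

Repository: github.com/ray-project/ray
Source: python/ray/data/tests/test_tensor.py, lines 47–224
Signature: test_tensors_basic(ray_start_regular_shared, tensor_format_context)

A pytest test for Ray Data tensor support. It checks that ray.data.range_tensor produces a fixed-shape tensor column with the expected schema and size, that the row and batch iterators yield NumPy arrays of the right shape, and that map (row-wise) and map_batches (NumPy and pandas batch formats) transform tensor columns as expected.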

import math

import numpy as np
import pandas as pd
import pyarrow as pa

import ray

# Also used below, and defined or imported elsewhere in test_tensor.py or its
# conftest (not shown): Schema, extract_values,
# create_arrow_fixed_shape_tensor_type, and the pytest fixtures
# ray_start_regular_shared and tensor_format_context.


def test_tensors_basic(ray_start_regular_shared, tensor_format_context):
    # Determine the expected tensor type based on the tensor format.
    expected_type = create_arrow_fixed_shape_tensor_type(shape=(3, 5), dtype=pa.int64())

    # Create tensors directly.
    tensor_shape = (3, 5)
    ds = ray.data.range_tensor(6, shape=tensor_shape, override_num_blocks=6)
    assert ds.count() == 6
    assert ds.schema() == Schema(pa.schema([("data", expected_type)]))
    # The actual size is slightly larger due to metadata.
    # We add 6 (one per tensor) offset values of 8 bytes each to account for the
    # in-memory representation of the PyArrow LargeList type.
    assert math.isclose(ds.size_bytes(), 5 * 3 * 6 * 8 + 6 * 8, rel_tol=0.1)

    # The row iterator yields one NumPy tensor per row.
    for tensor in ds.iter_rows():
        tensor = tensor["data"]
        assert isinstance(tensor, np.ndarray)
        assert tensor.shape == tensor_shape

    # The batch iterator yields NumPy tensors with a leading batch dimension.
    for tensor in ds.iter_batches(batch_size=2):
        tensor = tensor["data"]
        assert isinstance(tensor, np.ndarray)
        assert tensor.shape == (2,) + tensor_shape

    # Native (row-wise) format: each row is a dict of NumPy arrays.
    def np_mapper(arr):
        if "data" in arr:
            arr = arr["data"]
        else:
            arr = arr["id"]
        assert isinstance(arr, np.ndarray)
        return {"data": arr + 1}

    res = ray.data.range_tensor(2, shape=(2, 2)).map(np_mapper).take()
    np.testing.assert_equal(
        extract_values("data", res), [np.ones((2, 2)), 2 * np.ones((2, 2))]
    )

    # Explicit NumPy batch format.
    res = (
        ray.data.range_tensor(2, shape=(2, 2))
        .map_batches(np_mapper, batch_format="numpy")
        .take()
    )
    np.testing.assert_equal(
        extract_values("data", res), [np.ones((2, 2)), 2 * np.ones((2, 2))]
    )

    # Pandas batch format.
    def pd_mapper(df):
        assert isinstance(df, pd.DataFrame)
        return df + 2

    res = ray.data.range_tensor(2).map_batches(pd_mapper, batch_format="pandas").take()
    np.testing.assert_equal(extract_values("data", res), [np.array([2]), np.array([3])])

    # ... remainder of the test (through line 224) not shown.
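The expected values in the assertions of this test can be reproduced with plain NumPy. The sketch below is not Ray code: it uses a hypothetical helper, `fake_range_tensor`, which assumes that `range_tensor` fills row `i` with the value `i` as int64 and defaults to shape `(1,)`, consistent with the assertions in this test. Because it ignores Arrow metadata, Ray's real `size_bytes()` may differ slightly, which is why the test compares with `rel_tol=0.1`.

```python
import numpy as np


# Hypothetical stand-in for ray.data.range_tensor(n, shape=...): row i is a
# tensor of the given shape filled with the value i (int64). Does not call Ray.
def fake_range_tensor(n, shape=(1,)):
    return [np.full(shape, i, dtype=np.int64) for i in range(n)]


# 1) Expected size_bytes: 6 tensors of shape (3, 5) holding 8-byte int64 values,
#    plus one 8-byte offset per tensor for the LargeList-backed layout.
rows = fake_range_tensor(6, (3, 5))
data_bytes = sum(r.nbytes for r in rows)   # 6 * 3 * 5 * 8 = 720
offset_bytes = len(rows) * 8               # 6 * 8 = 48
expected_size = data_bytes + offset_bytes  # 720 + 48 = 768

# 2) A batch of 2 rows stacks the tensors along a new leading axis: (2, 3, 5).
batch = np.stack(rows[:2])

# 3) np_mapper adds 1 to each (2, 2) tensor: rows 0 and 1 become all ones and all twos.
mapped = [r + 1 for r in fake_range_tensor(2, (2, 2))]

# 4) pd_mapper adds 2 to the default (1,)-shaped tensors, giving [2] and [3].
pandas_like = [r + 2 for r in fake_range_tensor(2)]
```

The helper returns a list of per-row arrays rather than a batch so that each step mirrors one assertion in the test.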

Callers

nothing calls this directly

Calls (15)

Schema · class · 0.90
extract_values · function · 0.90
TensorArray · class · 0.90
list · function · 0.85
map_batches · method · 0.80
table · method · 0.80
from_arrow · method · 0.80
range · function · 0.70
count · method · 0.45
schema · method · 0.45
size_bytes · method · 0.45

Tested by

no test coverage detected
