hub / github.com/ray-project/ray / test_train_test_split

Function test_train_test_split

python/ray/data/tests/test_split.py:744–776
(ray_start_regular_shared_2_cpus)


def test_train_test_split(ray_start_regular_shared_2_cpus):
    ds = ray.data.range(8)

    # float
    train, test = ds.train_test_split(test_size=0.25)
    assert extract_values("id", train.take()) == [0, 1, 2, 3, 4, 5]
    assert extract_values("id", test.take()) == [6, 7]

    # int
    train, test = ds.train_test_split(test_size=2)
    assert extract_values("id", train.take()) == [0, 1, 2, 3, 4, 5]
    assert extract_values("id", test.take()) == [6, 7]

    # shuffle
    train, test = ds.train_test_split(test_size=0.25, shuffle=True, seed=1)
    assert extract_values("id", train.take()) == [7, 4, 6, 0, 5, 2]
    assert extract_values("id", test.take()) == [1, 3]

    # error handling
    with pytest.raises(TypeError):
        ds.train_test_split(test_size=[1])

    with pytest.raises(ValueError):
        ds.train_test_split(test_size=-1)

    with pytest.raises(ValueError):
        ds.train_test_split(test_size=0)

    with pytest.raises(ValueError):
        ds.train_test_split(test_size=1.1)

    with pytest.raises(ValueError):
        ds.train_test_split(test_size=9)
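
The non-shuffled cases asserted above can be modelled in plain Python. The sketch below is not Ray's implementation: `split_train_test` is a hypothetical helper that only reproduces the behaviour the test checks (a float is read as a fraction, an int as a row count, test rows come from the end). How the fraction is rounded and how boundary values the test does not cover (such as 1.0, or an int equal to the row count) are treated are assumptions made for illustration.

```python
from typing import List, Sequence, Tuple, TypeVar, Union

T = TypeVar("T")


def split_train_test(
    rows: Sequence[T], test_size: Union[int, float]
) -> Tuple[List[T], List[T]]:
    """Hypothetical helper: split rows into (train, test) without shuffling."""
    # bool is a subclass of int, so it is rejected explicitly (an assumption).
    if isinstance(test_size, bool) or not isinstance(test_size, (int, float)):
        raise TypeError(f"test_size must be an int or a float, got {type(test_size)}")

    n = len(rows)
    if isinstance(test_size, float):
        # A float is read as the fraction of rows that go to the test split.
        if not 0 < test_size < 1:
            raise ValueError("a float test_size must lie strictly between 0 and 1")
        n_test = int(n * test_size)  # truncation is an assumption
    else:
        # An int is read as an absolute number of test rows.
        if not 0 < test_size < n:
            raise ValueError("an int test_size must be positive and below the row count")
        n_test = test_size

    # Test rows are taken from the end, so train keeps the leading rows.
    cut = n - n_test
    return list(rows[:cut]), list(rows[cut:])


train, test = split_train_test(list(range(8)), 0.25)
print(train, test)  # [0, 1, 2, 3, 4, 5] [6, 7]
```

With eight rows, `test_size=0.25` and `test_size=2` both select two test rows, which is why the float and int cases in the test assert identical results. The shuffled case is left out because its exact ordering depends on Ray's seeded shuffle.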

Callers

nothing calls this directly

Calls 3

extract_values · Function · 0.90
train_test_split · Method · 0.80
take · Method · 0.45

Tested by

no test coverage detected
