(ray_start_regular_shared_2_cpus)
| 742 | |
| 743 | |
def test_train_test_split(ray_start_regular_shared_2_cpus):
    """Exercise ``Dataset.train_test_split`` on an 8-row ``range`` dataset.

    Covers a float ``test_size``, an int ``test_size``, a seeded shuffle,
    and the sizes that are expected to be rejected.
    """
    ds = ray.data.range(8)

    def _ids(split):
        # Collect the "id" column from every row of one split.
        return extract_values("id", split.take())

    # The float 0.25 and the int 2 are both expected to put the last two
    # of the eight rows into the test split, in order.
    for size in (0.25, 2):
        train, test = ds.train_test_split(test_size=size)
        assert _ids(train) == [0, 1, 2, 3, 4, 5]
        assert _ids(test) == [6, 7]

    # With shuffle=True and seed=1 the test pins one exact permutation.
    train, test = ds.train_test_split(test_size=0.25, shuffle=True, seed=1)
    assert _ids(train) == [7, 4, 6, 0, 5, 2]
    assert _ids(test) == [1, 3]

    # A list is expected to raise TypeError. Negative, zero, a fraction
    # above 1, and a count above the 8 available rows are each expected
    # to raise ValueError.
    rejected_sizes = [
        ([1], TypeError),
        (-1, ValueError),
        (0, ValueError),
        (1.1, ValueError),
        (9, ValueError),
    ]
    for bad_size, expected_error in rejected_sizes:
        with pytest.raises(expected_error):
            ds.train_test_split(test_size=bad_size)
| 777 | |
| 778 | |
| 779 | def test_train_test_split_stratified(ray_start_regular_shared_2_cpus): |
nothing calls this directly
no test coverage detected
searching dependent graphs…