(ray_start_regular_shared_2_cpus)
| 777 | |
| 778 | |
def test_train_test_split_stratified(ray_start_regular_shared_2_cpus):
    """Stratified ``train_test_split`` keeps each label's share in both splits.

    Six rows, two per label "A"/"B"/"C", split with ``test_size=0.5`` and
    ``stratify="label"`` should place exactly one row of every label in the
    train split and one in the test split, with no row lost or duplicated.
    """
    # Test basic stratification with simple dataset
    data = [
        {"id": 0, "label": "A"},
        {"id": 1, "label": "A"},
        {"id": 2, "label": "B"},
        {"id": 3, "label": "B"},
        {"id": 4, "label": "C"},
        {"id": 5, "label": "C"},
    ]
    ds = ray.data.from_items(data)

    # Test stratified split
    train, test = ds.train_test_split(test_size=0.5, stratify="label")

    # Check that we have the right number of samples
    assert train.count() == 3
    assert test.count() == 3

    # Use take_all() rather than take(): take() defaults to limit=20 and would
    # silently truncate the rows being checked if this fixture ever grows.
    train_rows = train.take_all()
    test_rows = test.take_all()

    # Each class should have exactly 1 sample in each split. Comparing the
    # sorted label lists (instead of counting only "A"/"B"/"C") also fails if
    # an unexpected label shows up.
    assert sorted(row["label"] for row in train_rows) == ["A", "B", "C"]
    assert sorted(row["label"] for row in test_rows) == ["A", "B", "C"]

    # The two splits must partition the input: disjoint and jointly complete.
    train_ids = {row["id"] for row in train_rows}
    test_ids = {row["id"] for row in test_rows}
    assert train_ids.isdisjoint(test_ids)
    assert train_ids | test_ids == {row["id"] for row in data}
| 808 | |
| 809 | |
| 810 | def test_train_test_split_shuffle_stratify_error(ray_start_regular_shared_2_cpus): |
nothing calls this directly
no test coverage detected
searching dependent graphs…