github.com/ray-project/ray

Function test_train_test_split_stratified

python/ray/data/tests/test_split.py:779–807
(ray_start_regular_shared_2_cpus)


```python
def test_train_test_split_stratified(ray_start_regular_shared_2_cpus):
    # Test basic stratification with simple dataset
    data = [
        {"id": 0, "label": "A"},
        {"id": 1, "label": "A"},
        {"id": 2, "label": "B"},
        {"id": 3, "label": "B"},
        {"id": 4, "label": "C"},
        {"id": 5, "label": "C"},
    ]
    ds = ray.data.from_items(data)

    # Test stratified split
    train, test = ds.train_test_split(test_size=0.5, stratify="label")

    # Check that we have the right number of samples
    assert train.count() == 3
    assert test.count() == 3

    # Check that class proportions are preserved
    train_labels = [row["label"] for row in train.take()]
    test_labels = [row["label"] for row in test.take()]

    train_label_counts = {label: train_labels.count(label) for label in ["A", "B", "C"]}
    test_label_counts = {label: test_labels.count(label) for label in ["A", "B", "C"]}

    # Each class should have exactly 1 sample in each split
    assert train_label_counts == {"A": 1, "B": 1, "C": 1}
    assert test_label_counts == {"A": 1, "B": 1, "C": 1}
```
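The stratified behaviour this test checks (each label keeps its share in both parts) can be illustrated without Ray. The sketch below is a minimal pure-Python version, assuming rows are dicts and that each class's test count is `len(class) * test_size` rounded half up. The helper `stratified_split` and that rounding rule are hypothetical and are not taken from Ray Data's `train_test_split`.

```python
import math
from collections import Counter, defaultdict
from typing import Any, Dict, List, Tuple

Row = Dict[str, Any]


def stratified_split(
    rows: List[Row], test_size: float, stratify: str
) -> Tuple[List[Row], List[Row]]:
    """Hypothetical helper: split rows so each class keeps its share in both parts.

    Illustration only; this is not Ray Data's implementation of
    Dataset.train_test_split.
    """
    if not 0.0 < test_size < 1.0:
        raise ValueError("test_size must be strictly between 0 and 1")

    # Group rows by class label; input order is preserved within each group.
    groups: Dict[Any, List[Row]] = defaultdict(list)
    for row in rows:
        groups[row[stratify]].append(row)

    train: List[Row] = []
    test: List[Row] = []
    for group in groups.values():
        # Per-class test count, rounded half up (an assumption of this sketch).
        n_test = math.floor(len(group) * test_size + 0.5)
        cut = len(group) - n_test
        # Leading rows of each class go to train, trailing rows to test.
        train.extend(group[:cut])
        test.extend(group[cut:])
    return train, test


if __name__ == "__main__":
    # Same six rows as the Ray test: two rows for each of A, B, C.
    data = [{"id": i, "label": label} for i, label in enumerate("AABBCC")]
    train, test = stratified_split(data, test_size=0.5, stratify="label")
    print(Counter(row["label"] for row in train))  # one row per label
    print(Counter(row["label"] for row in test))   # one row per label
```

Because each class is split independently, a class with a single row cannot appear in both parts; with `test_size=0.5` this sketch sends it entirely to the test set. The sketch also keeps input order within each class so the result is deterministic, whereas a production implementation would typically shuffle within each class before cutting.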

Callers

nothing calls this directly

Calls 3

train_test_split (method) · 0.80
count (method) · 0.45
take (method) · 0.45

Tested by

no test coverage detected
