Perform a random split on a batch: each row goes to train with probability (1 - test_size), or to test otherwise. This version ensures that the random choices are **stable per Ray task execution** by seeding the RNG with a combination of a user-specified seed (when given) and the Ray task index.
(batch: pa.Table)
| 2854 | raise ValueError("hash_column is not supported for random split") |
| 2855 | |
def random_split(batch: pa.Table):
    """Randomly tag each row of ``batch`` as train or test.

    Each row is assigned to train with probability ``1 - test_size`` and to
    test otherwise. The assignment is appended to the batch as a boolean
    column named ``_TRAIN_TEST_SPLIT_COLUMN`` (``True`` = train).

    The random choices are stable per Ray task execution. The RNG is seeded
    from the task index (``ctx.task_idx``), combined with the user-supplied
    ``seed`` when one is given. The RNG is then cached in the task context's
    ``kwargs``, so successive batches handled by the same task continue one
    random stream rather than restarting it.

    Note that even when ``seed`` is None the stream is seeded from the task
    index alone, so it is still deterministic for a given task index.

    Args:
        batch: The Arrow table whose rows are to be tagged.

    Returns:
        A table equal to ``batch`` with the boolean split column appended.
    """
    ctx = TaskContext.get_current()
    rng = ctx.kwargs.get("train_test_split_rng")
    if rng is None:
        # First batch seen by this task: build the RNG once and cache it so
        # later batches in the same task draw fresh values from the same stream.
        entropy = [ctx.task_idx] if seed is None else [ctx.task_idx, seed]
        rng = np.random.default_rng(entropy)
        ctx.kwargs["train_test_split_rng"] = rng

    # Bernoulli draw per row: uniform value below the train fraction -> train.
    is_train = rng.random(batch.num_rows) < (1 - test_size)
    return batch.append_column(
        _TRAIN_TEST_SPLIT_COLUMN, pa.array(is_train, type=pa.bool_())
    )
| 2879 | |
| 2880 | def hash_split(batch: pa.Table) -> tuple[pa.Table, pa.Table]: |
| 2881 | def key_to_bucket(key: Any) -> int: |
nothing calls this directly
no test coverage detected