Materialize and split the dataset into train and test subsets.

Examples:

    >>> import ray
    >>> ds = ray.data.range(8)
    >>> train, test = ds.train_test_split(test_size=0.25)
    >>> train.take_batch()
    {'id': array([0, 1, 2, 3, 4, 5])}
(
self,
test_size: Union[int, float],
*,
shuffle: bool = False,
seed: Optional[int] = None,
stratify: Optional[str] = None,
)
| 2612 | @ConsumptionAPI |
| 2613 | @PublicAPI(api_group=SMJ_API_GROUP) |
| 2614 | def train_test_split( |
| 2615 | self, |
| 2616 | test_size: Union[int, float], |
| 2617 | *, |
| 2618 | shuffle: bool = False, |
| 2619 | seed: Optional[int] = None, |
| 2620 | stratify: Optional[str] = None, |
| 2621 | ) -> Tuple["MaterializedDataset", "MaterializedDataset"]: |
| 2622 | """Materialize and split the dataset into train and test subsets. |
| 2623 | |
| 2624 | Examples: |
| 2625 | |
| 2626 | >>> import ray |
| 2627 | >>> ds = ray.data.range(8) |
| 2628 | >>> train, test = ds.train_test_split(test_size=0.25) |
| 2629 | >>> train.take_batch() |
| 2630 | {'id': array([0, 1, 2, 3, 4, 5])} |
| 2631 | >>> test.take_batch() |
| 2632 | {'id': array([6, 7])} |
| 2633 | |
| 2634 | Args: |
| 2635 | test_size: If float, should be between 0.0 and 1.0 and represent the |
| 2636 | proportion of the dataset to include in the test split. If int, |
| 2637 | represents the absolute number of test samples. The train split |
| 2638 | always complements the test split. |
| 2639 | shuffle: Whether or not to globally shuffle the dataset before splitting. |
| 2640 | Defaults to ``False``. This may be a very expensive operation with a |
| 2641 | large dataset. |
| 2642 | seed: Fix the random seed to use for shuffle, otherwise one is chosen |
| 2643 | based on system randomness. Ignored if ``shuffle=False``. |
| 2644 | stratify: Optional column name to use for stratified sampling. If provided, |
| 2645 | the splits will maintain the same proportions of each class in the |
| 2646 | stratify column across both train and test sets. |
| 2647 | |
| 2648 | Returns: |
| 2649 | Train and test subsets as two ``MaterializedDatasets``. |
| 2650 | |
| 2651 | .. seealso:: |
| 2652 | |
| 2653 | :meth:`Dataset.split_proportionately` |
| 2654 | """ |
| 2655 | ds = self |
| 2656 | |
| 2657 | if shuffle: |
| 2658 | ds = ds.random_shuffle(seed=seed) |
| 2659 | |
| 2660 | if not isinstance(test_size, (int, float)): |
| 2661 | raise TypeError(f"`test_size` must be int or float got {type(test_size)}.") |
| 2662 | |
| 2663 | # Validate that shuffle=True and stratify are not both specified |
| 2664 | if shuffle and stratify is not None: |
| 2665 | raise ValueError( |
| 2666 | "Cannot specify both 'shuffle=True' and 'stratify' parameters. " |
| 2667 | "Stratified splitting maintains class proportions and is incompatible with shuffling." |
| 2668 | ) |
| 2669 | |
| 2670 | # Handle stratified splitting |
| 2671 | if stratify is not None: |