Create a random hold-out test set using a stratified split (a plain random split is used when no label key is set). Args: number_of_test_examples: Number of test examples. random_seed: Random seed. Returns: A tuple of train and test RasaModelData.
(
self, number_of_test_examples: int, random_seed: int
)
| 586 | return self.sparse_feature_sizes |
| 587 | |
| 588 | def split( |
| 589 | self, number_of_test_examples: int, random_seed: int |
| 590 | ) -> Tuple["RasaModelData", "RasaModelData"]: |
| 591 | """Create random hold out test set using stratified split. |
| 592 | |
| 593 | Args: |
| 594 | number_of_test_examples: Number of test examples. |
| 595 | random_seed: Random seed. |
| 596 | |
| 597 | Returns: |
| 598 | A tuple of train and test RasaModelData. |
| 599 | """ |
| 600 | self._check_label_key() |
| 601 | |
| 602 | if self.label_key is None or self.label_sub_key is None: |
| 603 | # randomly split data as no label key is set |
| 604 | multi_values = [ |
| 605 | v |
| 606 | for attribute_data in self.data.values() |
| 607 | for data in attribute_data.values() |
| 608 | for v in data |
| 609 | ] |
| 610 | solo_values: List[Any] = [ |
| 611 | [] |
| 612 | for attribute_data in self.data.values() |
| 613 | for data in attribute_data.values() |
| 614 | for _ in data |
| 615 | ] |
| 616 | stratify = None |
| 617 | else: |
| 618 | # make sure that examples for each label value are in both split sets |
| 619 | label_ids = self._create_label_ids( |
| 620 | self.data[self.label_key][self.label_sub_key][0] |
| 621 | ) |
| 622 | label_counts: Dict[int, int] = dict( |
| 623 | zip( |
| 624 | *np.unique( |
| 625 | label_ids, |
| 626 | return_counts=True, |
| 627 | axis=0, |
| 628 | ) |
| 629 | ) |
| 630 | ) |
| 631 | |
| 632 | self._check_train_test_sizes(number_of_test_examples, label_counts) |
| 633 | |
| 634 | counts = np.array([label_counts[label] for label in label_ids]) |
| 635 | # we perform stratified train test split, |
| 636 | # which insures every label is present in the train and test data |
| 637 | # this operation can be performed only for labels |
| 638 | # that contain several data points |
| 639 | multi_values = [ |
| 640 | f[counts > 1].view(FeatureArray) |
| 641 | for attribute_data in self.data.values() |
| 642 | for features in attribute_data.values() |
| 643 | for f in features |
| 644 | ] |
| 645 | # collect data points that are unique for their label |