Method split

rasa/utils/tensorflow/model_data.py:588–662 (github.com/RasaHQ/rasa)

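Create a random hold-out test set. When a label key is set, the split is stratified by label, so every label with more than one example appears in both the train and the test set. Otherwise the data is split purely at random.

Args:
    number_of_test_examples: Number of examples to put in the test set.
    random_seed: Seed for the random split.

Returns:
    A tuple of train and test RasaModelData.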
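The excerpt below stops before the actual splitting call, so here is a minimal, dependency-free sketch of the guarantee described above: every label with at least two examples lands in both sets, and the test set has exactly `number_of_test_examples` items. `label_covering_split` is a hypothetical helper, not part of Rasa. It is a simplified stand-in for stratification: it reserves one test and one train example per repeated label and fills the rest of the test set at random, rather than allocating test examples in proportion to label frequency. Keeping single-example labels in the training set is also an assumption of this sketch.

```python
import random
from collections import defaultdict
from typing import Dict, Hashable, List, Sequence, Tuple


def label_covering_split(
    labels: Sequence[Hashable], number_of_test_examples: int, random_seed: int
) -> Tuple[List[int], List[int]]:
    """Split example indices into (train, test) so that every label with at
    least two examples appears in both sets. Singleton labels stay in train."""
    rng = random.Random(random_seed)

    # Group example indices by label (dicts keep first-seen label order).
    by_label: Dict[Hashable, List[int]] = defaultdict(list)
    for index, label in enumerate(labels):
        by_label[label].append(index)

    train: List[int] = []
    test: List[int] = []
    pool: List[int] = []  # indices that may go to either set
    for indices in by_label.values():
        if len(indices) == 1:
            # A label seen once cannot be in both sets; keep it for training.
            train.extend(indices)
            continue
        shuffled = list(indices)
        rng.shuffle(shuffled)
        test.append(shuffled[0])   # the label is guaranteed a test example
        train.append(shuffled[1])  # ...and a training example
        pool.extend(shuffled[2:])

    # Fill the rest of the test set at random from the unreserved examples.
    remaining = number_of_test_examples - len(test)
    if not 0 <= remaining <= len(pool):
        raise ValueError(
            f"Need between {len(test)} and {len(test) + len(pool)} test examples "
            f"to cover every repeated label, got {number_of_test_examples}."
        )
    rng.shuffle(pool)
    test.extend(pool[:remaining])
    train.extend(pool[remaining:])
    return sorted(train), sorted(test)


labels = ["greet", "greet", "greet", "bye", "bye", "thanks"]
train_idx, test_idx = label_covering_split(labels, number_of_test_examples=2, random_seed=42)
# test_idx holds one "greet" and one "bye" index; index 5 ("thanks") is in train_idx.
```

The lower bound on the test size is the number of repeated labels, and the upper bound leaves one training example per repeated label, which mirrors why the source validates the requested size against the label counts before splitting.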

split(self, number_of_test_examples: int, random_seed: int) -> Tuple["RasaModelData", "RasaModelData"]
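When no label key is set, the source takes the label-free branch and sets `stratify = None`, so the hold-out set is drawn purely at random. The sketch below shows what the two parameters control in that case: `number_of_test_examples` is an absolute count, and `random_seed` makes the shuffle reproducible. `random_holdout` is a hypothetical helper, not Rasa's code, and it assumes the caller applies the same index lists to every feature array.

```python
import random
from typing import List, Tuple


def random_holdout(
    n_examples: int, number_of_test_examples: int, random_seed: int
) -> Tuple[List[int], List[int]]:
    """Shuffle example indices with a seeded RNG and cut off a test set."""
    if not 0 < number_of_test_examples < n_examples:
        raise ValueError("number_of_test_examples must be between 1 and n_examples - 1")
    indices = list(range(n_examples))
    random.Random(random_seed).shuffle(indices)  # same seed -> same order
    test = sorted(indices[:number_of_test_examples])
    train = sorted(indices[number_of_test_examples:])
    return train, test


# Applying the same index lists to every feature array keeps rows aligned.
train_idx, test_idx = random_holdout(10, number_of_test_examples=3, random_seed=1)
```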

Source:

```python
def split(
    self, number_of_test_examples: int, random_seed: int
) -> Tuple["RasaModelData", "RasaModelData"]:
    """Create random hold out test set using stratified split.

    Args:
        number_of_test_examples: Number of test examples.
        random_seed: Random seed.

    Returns:
        A tuple of train and test RasaModelData.
    """
    self._check_label_key()

    if self.label_key is None or self.label_sub_key is None:
        # randomly split data as no label key is set
        multi_values = [
            v
            for attribute_data in self.data.values()
            for data in attribute_data.values()
            for v in data
        ]
        solo_values: List[Any] = [
            []
            for attribute_data in self.data.values()
            for data in attribute_data.values()
            for _ in data
        ]
        stratify = None
    else:
        # make sure that examples for each label value are in both split sets
        label_ids = self._create_label_ids(
            self.data[self.label_key][self.label_sub_key][0]
        )
        label_counts: Dict[int, int] = dict(
            zip(
                *np.unique(
                    label_ids,
                    return_counts=True,
                    axis=0,
                )
            )
        )

        self._check_train_test_sizes(number_of_test_examples, label_counts)

        counts = np.array([label_counts[label] for label in label_ids])
        # we perform a stratified train/test split,
        # which ensures every label is present in the train and test data;
        # this operation can be performed only for labels
        # that contain several data points
        multi_values = [
            f[counts > 1].view(FeatureArray)
            for attribute_data in self.data.values()
            for features in attribute_data.values()
            for f in features
        ]
        # collect data points that are unique for their label
        # ... (remainder of the method not shown in this excerpt)
```
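In the labelled branch, the source counts how often each label id occurs and masks every feature array with `counts > 1`. The toy example below reproduces that counting and masking on one feature array, assuming `_create_label_ids` yields a 1-D integer array; the label values and the `features` array are made up. The source also re-views each masked array as `FeatureArray`, which this sketch omits, and the `counts == 1` mask for single-example labels is an assumption about the step the excerpt cuts off.

```python
import numpy as np

# Hypothetical label ids for six examples; stands in for the output of
# _create_label_ids, assumed here to be a 1-D integer array.
label_ids = np.array([0, 0, 1, 2, 2, 2])

# Map each distinct label id to its number of occurrences, as the source does.
label_counts = dict(zip(*np.unique(label_ids, return_counts=True, axis=0)))

# Per-example count of that example's own label, aligned with label_ids.
counts = np.array([label_counts[label] for label in label_ids])

# One hypothetical feature array with one row per example.
features = np.arange(12).reshape(6, 2)

# Rows whose label occurs more than once can take part in the stratified split.
multi = features[counts > 1]
# Rows whose label occurs exactly once (label 1 here) are set aside.
solo = features[counts == 1]
```

Masking with the same `counts` vector keeps rows aligned across all feature arrays, because every array has one row per example in the same order.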
