Loads one or more datasets with varying training set proportions. Args: data_config (`DataArguments` or `dict`): Dataset configuration and split proportions. splits (`List[str]`, *optional*, defaults to `['train', 'test']`): Dataset splits to load and mix.
(
data_config: DataArguments | dict,
splits: Optional[List[str]] = None,
configs: Optional[List[str]] = None,
columns_to_keep: Optional[List[str]] = None,
shuffle: bool = True,
)
| 123 | |
| 124 | |
def get_datasets(
    data_config: DataArguments | dict,
    splits: Optional[List[str]] = None,
    configs: Optional[List[str]] = None,
    columns_to_keep: Optional[List[str]] = None,
    shuffle: bool = True,
) -> DatasetDict:
    """
    Loads one or more datasets with varying training set proportions.

    The mixing proportions are extracted from `data_config` and every other
    argument is forwarded unchanged to `mix_datasets`, which does the loading.

    Args:
        data_config (`DataArguments` or `dict`):
            Dataset configuration and split proportions. A `DataArguments`
            instance (or subclass instance) contributes its `dataset_mixer`
            attribute; a plain `dict` is used directly as the mixer.
        splits (`List[str]`, *optional*, defaults to `None`):
            Dataset splits to load and mix. Assumes the splits exist in all datasets and have a
            `train_` or `test_` prefix. `None` is passed through as-is.
            NOTE(review): earlier docs state this resolves to `['train', 'test']`;
            that default would be applied inside `mix_datasets` - confirm there.
        configs (`List[str]`, *optional*, defaults to `None`):
            List of dataset config names. If given must be the same length as 'data_config' keys.
        columns_to_keep (`List[str]`, *optional*, defaults to `None`):
            Column names to keep in the dataset. Useful in the datamixer to avoid schema conflicts,
            and for cpt this should be (at least) the text column.
        shuffle (`bool`, *optional*, defaults to `True`):
            Whether to shuffle the training and testing/validation data.

    Returns:
        [`DatasetDict`]: The dataset dictionary containing the loaded datasets.

    Raises:
        ValueError: If `data_config` is neither a `DataArguments` nor a `dict`.
    """
    # isinstance (rather than an exact type comparison) so that subclasses of
    # DataArguments are accepted instead of falling through to the error branch.
    if isinstance(data_config, DataArguments):
        # Structure of the config to read the datasets and their mix
        # dataset_mixer:
        #     - 'dataset1': 0.5
        #     - 'dataset2': 0.3
        #     - 'dataset3': 0.2
        dataset_mixer = data_config.dataset_mixer
    elif isinstance(data_config, dict):
        # Structure of the input is:
        #     dataset_mixer = {
        #         "dataset1": 0.5,
        #         "dataset2": 0.3,
        #         "dataset3": 0.2,
        #     }
        dataset_mixer = data_config
    else:
        raise ValueError(f"Data config {data_config} not recognized.")

    return mix_datasets(
        dataset_mixer,
        splits=splits,
        configs=configs,
        columns_to_keep=columns_to_keep,
        shuffle=shuffle,
    )
| 177 | |
| 178 | |
| 179 | def mix_datasets( |