(tokenizer, args)
| 192 | |
| 193 | |
def create_datasets(tokenizer, args):
    """Load the dataset and build the train/validation ``ConstantLengthDataset`` pair.

    Args:
        tokenizer: Tokenizer forwarded to ``chars_token_ratio`` and to both
            ``ConstantLengthDataset`` instances.
        args: Namespace providing ``dataset_name``, ``subset``, ``split``,
            ``num_workers``, ``streaming``, ``size_valid_set``,
            ``shuffle_buffer``, ``seed``, ``seq_length``,
            ``input_column_name`` and ``output_column_name``.

    Returns:
        A ``(train_dataset, valid_dataset)`` tuple. The training dataset is
        created with ``infinite=True`` and the validation dataset with
        ``infinite=False``.
    """
    dataset = load_dataset(
        args.dataset_name,
        data_dir=args.subset,
        split=args.split,
        use_auth_token=True,
        # num_proc is only passed when the dataset is not streamed.
        num_proc=args.num_workers if not args.streaming else None,
        streaming=args.streaming,
    )
    if args.streaming:
        print("Loading the dataset in streaming mode")
        # The first `size_valid_set` examples form the validation set; the
        # remainder is the training stream, shuffled through a buffer.
        valid_data = dataset.take(args.size_valid_set)
        train_data = dataset.skip(args.size_valid_set)
        train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed)
    else:
        # When `args.split` names a single split, `load_dataset` returns one
        # `Dataset` rather than a `DatasetDict` (a dict subclass), so there
        # are no "train"/"test" keys to index. Hold out `size_valid_set`
        # examples in that case, matching what the streaming branch does.
        if not isinstance(dataset, dict):
            dataset = dataset.train_test_split(test_size=args.size_valid_set, seed=args.seed)
        train_data = dataset["train"]
        valid_data = dataset["test"]
        print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}")

    # The ratio is computed on the training data only and reused for both datasets.
    chars_per_token = chars_token_ratio(train_data, tokenizer, args.input_column_name, args.output_column_name)
    print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}")

    train_dataset = ConstantLengthDataset(
        tokenizer,
        train_data,
        infinite=True,
        seq_length=args.seq_length,
        chars_per_token=chars_per_token,
        input_column_name=args.input_column_name,
        output_column_name=args.output_column_name,
    )
    valid_dataset = ConstantLengthDataset(
        tokenizer,
        valid_data,
        infinite=False,
        seq_length=args.seq_length,
        chars_per_token=chars_per_token,
        input_column_name=args.input_column_name,
        output_column_name=args.output_column_name,
    )
    return train_dataset, valid_dataset
| 235 | |
| 236 | |
| 237 | def run_training(args, train_data, val_data): |
no test coverage detected