Makes the training/validation/test data loaders.
(args, tokenizer)
| 246 | |
| 247 | |
| 248 | def make_loaders(args, tokenizer): |
| 249 | """makes training/val/test""" |
| 250 | |
| 251 | if args.use_tfrecords: |
| 252 | return make_tfrecord_loaders(args) |
| 253 | world_size = torch.distributed.get_world_size(group=mpu.get_data_parallel_group()) |
| 254 | if args.loader_scatter is not None: |
| 255 | assert world_size % args.loader_scatter == 0 |
| 256 | batch_size = args.batch_size * world_size |
| 257 | eval_batch_size = batch_size |
| 258 | if args.eval_batch_size is not None: |
| 259 | eval_batch_size = args.eval_batch_size * world_size |
| 260 | seq_length = args.seq_length |
| 261 | if seq_length < 0: |
| 262 | seq_length = seq_length * world_size |
| 263 | eval_seq_length = args.eval_seq_length |
| 264 | if eval_seq_length is not None and eval_seq_length < 0: |
| 265 | eval_seq_length = eval_seq_length * world_size |
| 266 | split = get_split(args) |
| 267 | data_set_args = { |
| 268 | 'path': args.train_data, |
| 269 | 'seq_length': seq_length, |
| 270 | 'mem_length': args.mem_length, |
| 271 | 'delim': args.delim, |
| 272 | 'text_key': args.text_key, |
| 273 | 'label_key': 'label', |
| 274 | 'ds_type': args.data_set_type, |
| 275 | 'split': split, |
| 276 | 'loose': args.loose_json, |
| 277 | 'max_preds_per_seq': args.max_preds_per_seq, |
| 278 | 'presplit_sentences': args.presplit_sentences, |
| 279 | 'sample_one_document': args.sample_one_document, |
| 280 | 'filter_english': args.filter_english, |
| 281 | 'pre_tokenize': not args.no_pre_tokenize, |
| 282 | 'tokenizer': tokenizer, |
| 283 | 'save_splits': args.save_splits, |
| 284 | 'load_splits': args.load_splits, |
| 285 | 'save_test_data': args.save_test_data, |
| 286 | 'no_lazy_loader': args.no_lazy_loader, |
| 287 | 'loader_scatter': args.loader_scatter, |
| 288 | 'data_parallel_rank': mpu.get_data_parallel_rank(), |
| 289 | "non_sentence_start": args.non_sentence_start, |
| 290 | "half_lazy_loader": args.half_lazy_loader |
| 291 | } |
| 292 | |
| 293 | eval_set_args = copy.copy(data_set_args) |
| 294 | eval_set_args['split'] = [1.] |
| 295 | # if optional eval args were set then replace their |
| 296 | # equivalent values in the arg dict |
| 297 | if eval_seq_length: |
| 298 | eval_set_args['seq_length'] = eval_seq_length |
| 299 | if args.eval_max_preds_per_seq: |
| 300 | eval_set_args['max_preds_per_seq'] = args.eval_max_preds_per_seq |
| 301 | if args.eval_text_key is not None: |
| 302 | eval_set_args['text_key'] = args.eval_text_key |
| 303 | |
| 304 | # make datasets splits and tokenizer |
| 305 | train, valid, test = None, None, None |
no test coverage detected