Load the train/val/test datasets from shuffled TFRecords.
(args)
| 77 | |
| 78 | |
def make_tfrecord_loaders(args):
    """Build TFRecord data loaders for the train/validation/test splits.

    The training loader is always created from ``args.train_data``.  The
    validation and test loaders share one set of evaluation settings and
    are created only when ``args.valid_data`` / ``args.test_data`` are not
    None.  A tokenizer is then built via ``data_utils.make_tokenizer``.

    Args:
        args: configuration namespace; the attributes read are
            ``batch_size``, ``seq_length``, ``max_preds_per_seq``,
            ``num_workers``, ``seed``, ``rank``, ``train_data``,
            ``eval_seq_length``, ``eval_max_preds_per_seq``, ``valid_data``,
            ``test_data``, ``tokenizer_type``, ``tokenizer_path``,
            ``vocab_size``, ``tokenizer_model_type`` and ``cache_dir``.

    Returns:
        ``((train, valid, test), tokenizer)``; ``valid`` and ``test`` are
        None when the corresponding data path is None.
    """
    import data_utils.tf_dl

    loader_cls = data_utils.tf_dl.TFRecordDataLoader

    # Settings for the training split.  The worker count is clamped to at
    # least 1, while threaded loading is switched on only when the caller
    # actually asked for workers.  The seed is offset by rank + 1
    # (presumably so each distributed rank shuffles differently -- confirm).
    train_kwargs = dict(
        batch_size=args.batch_size,
        max_seq_len=args.seq_length,
        max_preds_per_seq=args.max_preds_per_seq,
        train=True,
        num_workers=max(args.num_workers, 1),
        seed=args.seed + args.rank + 1,
        threaded_dl=args.num_workers > 0,
    )
    train = loader_cls(args.train_data, **train_kwargs)

    # Evaluation splits reuse the training settings with train=False and
    # optional overrides for sequence length and predictions per sequence.
    eval_kwargs = dict(train_kwargs, train=False)
    if args.eval_seq_length is not None:
        eval_kwargs['max_seq_len'] = args.eval_seq_length
    if args.eval_max_preds_per_seq is not None:
        eval_kwargs['max_preds_per_seq'] = args.eval_max_preds_per_seq

    def _optional_eval_loader(path):
        # A missing path means the split is simply not loaded.
        if path is None:
            return None
        return loader_cls(path, **eval_kwargs)

    valid = _optional_eval_loader(args.valid_data)
    test = _optional_eval_loader(args.test_data)

    # The training loader is handed to make_tokenizer as its second
    # positional argument (presumably as the corpus source -- verify
    # against make_tokenizer's signature).
    tokenizer = data_utils.make_tokenizer(
        args.tokenizer_type,
        train,
        args.tokenizer_path,
        args.vocab_size,
        args.tokenizer_model_type,
        cache_dir=args.cache_dir,
    )

    return (train, valid, test), tokenizer
| 114 | |
| 115 | |
| 116 | def make_loaders(args): |