makes training/val/test
(args)
| 114 | |
| 115 | |
| 116 | def make_loaders(args): |
| 117 | """makes training/val/test""" |
| 118 | |
| 119 | if args.use_tfrecords: |
| 120 | return make_tfrecord_loaders(args) |
| 121 | world_size = torch.distributed.get_world_size( |
| 122 | group=mpu.get_data_parallel_group()) |
| 123 | batch_size = args.batch_size * world_size |
| 124 | eval_batch_size = batch_size |
| 125 | if args.eval_batch_size is not None: |
| 126 | eval_batch_size = args.eval_batch_size * world_size |
| 127 | seq_length = args.seq_length |
| 128 | if seq_length < 0: |
| 129 | seq_length = seq_length * world_size |
| 130 | eval_seq_length = args.eval_seq_length |
| 131 | if eval_seq_length is not None and eval_seq_length < 0: |
| 132 | eval_seq_length = eval_seq_length * world_size |
| 133 | split = get_split(args) |
| 134 | data_set_args = { |
| 135 | 'path': args.train_data, |
| 136 | 'seq_length': seq_length, |
| 137 | 'lazy': args.lazy_loader, |
| 138 | 'delim': args.delim, |
| 139 | 'text_key': args.text_key, |
| 140 | 'label_key': 'label', |
| 141 | 'non_binary_cols': None, |
| 142 | 'ds_type': args.data_set_type, |
| 143 | 'split': split, |
| 144 | 'loose': args.loose_json, |
| 145 | 'tokenizer_type': args.tokenizer_type, |
| 146 | 'tokenizer_model_path': args.tokenizer_path, |
| 147 | 'vocab_size': args.vocab_size, |
| 148 | 'model_type': args.tokenizer_model_type, |
| 149 | 'cache_dir': args.cache_dir, |
| 150 | 'max_preds_per_seq': args.max_preds_per_seq, |
| 151 | 'presplit_sentences': args.presplit_sentences} |
| 152 | |
| 153 | eval_set_args = copy.copy(data_set_args) |
| 154 | eval_set_args['split'] = [1.] |
| 155 | # if optional eval args were set then replace their |
| 156 | # equivalent values in the arg dict |
| 157 | if eval_seq_length: |
| 158 | eval_set_args['seq_length'] = eval_seq_length |
| 159 | if args.eval_max_preds_per_seq: |
| 160 | eval_set_args['max_preds_per_seq'] = args.eval_max_preds_per_seq |
| 161 | if args.eval_text_key is not None: |
| 162 | eval_set_args['text_key'] = args.eval_text_key |
| 163 | |
| 164 | # make dataset splits and tokenizer |
| 165 | train = None |
| 166 | valid = None |
| 167 | test = None |
| 168 | |
| 169 | if args.train_data is not None: |
| 170 | train, tokenizer = data_utils.make_dataset(**data_set_args) |
| 171 | if data_utils.should_split(split): |
| 172 | train, valid, test = train |
| 173 | eval_set_args['tokenizer'] = tokenizer |
no test coverage detected