MCPcopy
hub / github.com/deepspeedai/DeepSpeedExamples / get_train_dataset

Function get_train_dataset

bing_bert/deepspeed_train.py:113–175  ·  view source on GitHub ↗
(args, index, finetune=False, shuffle=True)

Source from the content-addressed store, hash-verified

111 return (not args.no_cuda and dist.get_rank() == 0) or (args.no_cuda and args.local_rank == -1)
112
113def get_train_dataset(args, index, finetune=False, shuffle=True):
114 assert not finetune, "finetune not supported"
115 i = 0
116 dataloaders = {}
117 datalengths = []
118 batchs_per_dataset = []
119 batch_mapping = {}
120
121 config = args.config
122 dataset_paths = config["data"]["datasets"]
123 dataset_flags = config["data"]["flags"]
124
125 # Pretraining dataset
126 if dataset_flags.get("pretrain_dataset", False):
127 pretrain_type = dataset_flags.get("pretrain_type")
128
129 if pretrain_type == "wiki_bc":
130 # Load Wiki Dataset
131 wiki_pretrain_dataset = PreTrainingDataset(
132 args.tokenizer,
133 os.path.join(args.data_path_prefix, dataset_paths['wiki_pretrain_dataset']),
134 args.logger,
135 args.max_seq_length,
136 index,
137 PretrainDataType.NUMPY,
138 args.max_predictions_per_seq)
139 datalengths.append(len(wiki_pretrain_dataset))
140 dataloaders[i] = get_dataloader(args, wiki_pretrain_dataset)
141 batch_mapping[i] = PretrainBatch
142 batchs_per_dataset.append(
143 get_effective_batch(args, len(wiki_pretrain_dataset)))
144 i += 1
145
146 bc_pretrain_dataset = PreTrainingDataset(
147 args.tokenizer,
148 os.path.join(args.data_path_prefix, dataset_paths['bc_pretrain_dataset']),
149 args.logger,
150 args.max_seq_length,
151 index,
152 PretrainDataType.NUMPY,
153 args.max_predictions_per_seq
154 )
155 datalengths.append(len(bc_pretrain_dataset))
156 dataloaders[i] = get_dataloader(args, bc_pretrain_dataset)
157 batch_mapping[i] = PretrainBatch
158 batchs_per_dataset.append(
159 get_effective_batch(args, len(bc_pretrain_dataset)))
160 i += 1
161
162 dataset_batches = []
163 for i, batch_count in enumerate(batchs_per_dataset):
164 dataset_batches.extend([i] * batch_count)
165
166 # shuffle
167 if shuffle:
168 random.shuffle(dataset_batches)
169
170 dataset_picker = []

Callers (1)

train (Function) · 0.85

Calls (5)

PreTrainingDataset (Class) · 0.90
get_dataloader (Function) · 0.85
get_effective_batch (Function) · 0.85
append (Method) · 0.80
extend (Method) · 0.80

Tested by

no test coverage detected