hub / github.com/deepspeedai/DeepSpeedExamples / train

Function train

bing_bert/deepspeed_train.py:177–234
(args, index, model, optimizer, finetune=False)

Source from the content-addressed store, hash-verified

    # end of the preceding function in the file
    return dataset_picker, dataloaders, sum(datalengths)


def train(args, index, model, optimizer, finetune=False):
    global global_step
    global global_data_samples
    global last_global_step_from_restore

    dataset_picker, dataloaders, total_length = get_train_dataset(args, index, finetune)
    current_data_sample_count = global_data_samples
    global_data_samples += total_length
    config = args.config
    logger = args.logger
    print('total_length', total_length, 'global_data_samples', global_data_samples)

    model.train()

    epoch_step = 0
    for step, dataset_type in enumerate(tqdm(dataset_picker, smoothing=1)):
        try:
            batch = next(dataloaders[dataset_type])
            batch = tuple(t.to(args.device) for t in batch)  # Move to GPU

            # Calculate forward pass
            loss = model.network(batch)
            unscaled_loss = loss.item()
            current_data_sample_count += (args.train_micro_batch_size_per_gpu * dist.get_world_size())

            model.network.backward(loss)

            if model.network.is_gradient_accumulation_boundary():
                if args.fp16:
                    # modify learning rate with special warm up BERT uses
                    # if args.fp16 is False, BertAdam is used that handles this automatically
                    lr_this_step = update_learning_rate(config, global_step, optimizer)

                report_step_metrics(args, lr_this_step, unscaled_loss, global_step, current_data_sample_count)

                model.network.step()

                report_lamb_coefficients(args, optimizer)
                global_step += 1
                epoch_step += 1
            else:
                # Call DeepSpeed engine step on micro steps
                model.network.step()

        except StopIteration:
            continue

        current_global_step = global_step - last_global_step_from_restore
        if is_time_to_exit(args=args,
                           epoch_steps=epoch_step,
                           global_steps=current_global_step):
            print(f'Warning: Early epoch termination due to max steps limit, epoch step ={epoch_step}, global step = {current_global_step}, epoch = {index+1}')
            break

    # Run Validation Loss
    if not finetune and args.max_seq_length == 512:
        logger.info(f"TRAIN BATCH SIZE: {args.train_micro_batch_size_per_gpu}")
        pretrain_validation(args, index, model)
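
The loop's control flow can be tried out without a GPU or DeepSpeed. The following is a minimal sketch, not the repository's code: `FakeEngine` and `run_epoch` are hypothetical names, and the boundary rule (`(micro_steps + 1) % gradient_accumulation_steps == 0`) is an assumption about how the engine decides a boundary. It shows three things from the listing: an exhausted loader is skipped via `StopIteration`, the engine's `step()` is called on every micro-batch, and the global step advances only on accumulation boundaries.

```python
class FakeEngine:
    """Hypothetical stand-in for model.network (the DeepSpeed engine)."""

    def __init__(self, gradient_accumulation_steps):
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.micro_steps = 0

    def is_gradient_accumulation_boundary(self):
        # Assumed rule: true on the last micro-batch of each accumulation window.
        return (self.micro_steps + 1) % self.gradient_accumulation_steps == 0

    def step(self):
        # Called after every micro-batch, boundary or not.
        self.micro_steps += 1


def run_epoch(dataset_picker, dataloaders, engine, micro_batch_size, world_size):
    """Mirror the listing's bookkeeping with the forward/backward passes omitted."""
    global_step = 0
    samples_seen = 0
    for dataset_type in dataset_picker:
        try:
            next(dataloaders[dataset_type])
        except StopIteration:
            # The picked loader is exhausted: skip this iteration.
            continue

        # Each micro-batch is processed on every rank.
        samples_seen += micro_batch_size * world_size

        if engine.is_gradient_accumulation_boundary():
            engine.step()
            global_step += 1  # counted only when the window completes
        else:
            engine.step()  # micro step, no global-step increment
    return global_step, samples_seen


engine = FakeEngine(gradient_accumulation_steps=2)
loaders = {0: iter(range(4)), 1: iter(range(1))}  # loader 1 runs dry early
steps, samples = run_epoch([0, 1, 0, 1, 0, 0], loaders, engine,
                           micro_batch_size=8, world_size=4)
print(steps, samples, engine.micro_steps)
```

The sketch wraps only the `next()` call in `try`, whereas the listing wraps the whole iteration body; the narrower scope is a choice made here so that a `StopIteration` raised elsewhere would not be silently swallowed.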

Callers 1

run · Function · 0.70

Calls 11

is_time_to_exit · Function · 0.90
get_train_dataset · Function · 0.85
update_learning_rate · Function · 0.85
report_step_metrics · Function · 0.85
report_lamb_coefficients · Function · 0.85
pretrain_validation · Function · 0.85
train · Method · 0.80
info · Method · 0.80
to · Method · 0.45
backward · Method · 0.45
step · Method · 0.45

Tested by

no test coverage detected