MCPcopy
hub / github.com/deepspeedai/DeepSpeedExamples / train

Function train

Megatron-LM/pretrain_bert.py:301–375  ·  view source on GitHub ↗

Train the model.

(model, optimizer, lr_scheduler,
          train_data_iterator, val_data_iterator, timers, args)

Source from the content-addressed store, hash-verified

299
300
301def train(model, optimizer, lr_scheduler,
302 train_data_iterator, val_data_iterator, timers, args):
303 """Train the model."""
304
305 # Turn on training mode which enables dropout.
306 model.train()
307
308 # Tracking loss.
309 total_lm_loss = 0.0
310 total_nsp_loss = 0.0
311
312 # Iterations.
313 iteration = args.iteration
314 skipped_iters = 0
315
316 timers('interval time').start()
317 report_memory_flag = True
318 while iteration < args.train_iters:
319
320 lm_loss, nsp_loss, skipped_iter = train_step(train_data_iterator,
321 model,
322 optimizer,
323 lr_scheduler,
324 args, timers)
325 skipped_iters += skipped_iter
326 iteration += 1
327
328 # Update losses.
329 total_lm_loss += lm_loss.data.detach().float()
330 total_nsp_loss += nsp_loss.data.detach().float()
331
332 # Logging.
333 if iteration % args.log_interval == 0:
334 learning_rate = optimizer.param_groups[0]['lr']
335 avg_nsp_loss = total_nsp_loss.item() / args.log_interval
336 avg_lm_loss = total_lm_loss.item() / args.log_interval
337 elapsed_time = timers('interval time').elapsed()
338 log_string = ' iteration {:8d}/{:8d} |'.format(iteration,
339 args.train_iters)
340 log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
341 elapsed_time * 1000.0 / args.log_interval)
342 log_string += ' learning rate {:.3E} |'.format(learning_rate)
343 log_string += ' lm loss {:.6E} |'.format(avg_lm_loss)
344 log_string += ' nsp loss {:.6E} |'.format(avg_nsp_loss)
345 if args.fp16:
346 log_string += ' loss scale {:.1f} |'.format(
347 optimizer.loss_scale)
348 print_rank_0(log_string)
349 total_nsp_loss = 0.0
350 total_lm_loss = 0.0
351 if report_memory_flag:
352 report_memory('after {} iterations'.format(iteration))
353 report_memory_flag = False
354 timers.log(['forward', 'backward', 'optimizer', 'batch generator',
355 'data loader'],
356 normalizer=args.log_interval)
357 # Checkpointing
358 if args.save and args.save_interval and iteration % args.save_interval == 0:

Callers 1

main — Function · 0.70

Calls 9

print_rank_0 — Function · 0.90
report_memory — Function · 0.90
save_checkpoint — Function · 0.90
train — Method · 0.80
train_step — Function · 0.70
start — Method · 0.45
elapsed — Method · 0.45
log — Method · 0.45

Tested by

no test coverage detected