MCPcopy
hub / github.com/deepspeedai/DeepSpeedExamples / train

Function train

Megatron-LM/pretrain_gpt2.py:391–465  ·  view source on GitHub ↗

Train the model.

(model, optimizer, lr_scheduler, train_data_iterator, val_data_iterator, timers, args)

Source from the content-addressed store, hash-verified

389
390
391def train(model, optimizer, lr_scheduler,
392 train_data_iterator, val_data_iterator, timers, args):
393 """Train the model."""
394
395 # Turn on training mode which enables dropout.
396 model.train()
397
398 # Tracking loss.
399 total_lm_loss = 0.0
400
401 # Iterations.
402 iteration = args.iteration
403 skipped_iters = 0
404
405 timers('interval time').start()
406 report_memory_flag = True
407 while iteration < args.train_iters:
408
409 lm_loss, skipped_iter = train_step(train_data_iterator,
410 model,
411 optimizer,
412 lr_scheduler,
413 args, timers)
414 skipped_iters += skipped_iter
415 iteration += 1
416
417 # Update losses.
418 total_lm_loss += lm_loss.data.detach().float()
419
420 # Logging.
421 if iteration % args.log_interval == 0:
422 learning_rate = optimizer.param_groups[0]['lr']
423 avg_lm_loss = total_lm_loss.item() / args.log_interval
424 elapsed_time = timers('interval time').elapsed()
425 log_string = ' iteration {:8d}/{:8d} |'.format(iteration,
426 args.train_iters)
427 log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
428 elapsed_time * 1000.0 / args.log_interval)
429 log_string += ' learning rate {:.3E} |'.format(learning_rate)
430 log_string += ' lm loss {:.6E} |'.format(avg_lm_loss)
431 if args.fp16:
432 log_string += ' loss scale {:.1f} |'.format(
433 optimizer.cur_scale if args.deepspeed else optimizer.loss_scale)
434 print_rank_0(log_string)
435 total_lm_loss = 0.0
436 if report_memory_flag:
437 report_memory('after {} iterations'.format(iteration))
438 report_memory_flag = False
439 if USE_TORCH_DDP:
440 timers.log(['forward', 'backward', 'optimizer',
441 'batch generator', 'data loader'],
442 normalizer=args.log_interval)
443 else:
444 timers.log(['forward', 'backward', 'allreduce', 'optimizer',
445 'batch generator', 'data loader'],
446 normalizer=args.log_interval)
447 # Checkpointing
448 if args.save and args.save_interval and iteration % args.save_interval == 0:

Callers 1

main (Function) · 0.70

Calls 9

print_rank_0 (Function) · 0.90
report_memory (Function) · 0.90
save_checkpoint (Function) · 0.90
train (Method) · 0.80
train_step (Function) · 0.70
start (Method) · 0.45
elapsed (Method) · 0.45
log (Method) · 0.45

Tested by

no test coverage detected