Repository: github.com/zai-org/CodeGeeX

Function train

codegeex/megatron/training.py, lines 882–1019

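Runs the main training loop (docstring: "Train the model function."). It puts every model module into training mode, optionally initializes a Weights & Biases experiment on the last rank, and then, starting from `args.iteration`, calls `train_step` once per iteration until `args.train_iters` is reached or, if `args.train_tokens` is set, until `args.consumed_train_tokens` reaches it.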
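The `while` condition in `train` combines a hard iteration budget with an optional token budget. A minimal sketch of that stopping rule in plain Python; the helpers `should_continue` and `count_steps` are hypothetical, and the toy loop adds a fixed token count after each step (where CodeGeeX actually updates `args.consumed_train_tokens` lies outside the lines shown).

```python
from typing import Optional


def should_continue(
    iteration: int,
    train_iters: int,
    train_tokens: Optional[int],
    consumed_train_tokens: int,
) -> bool:
    """Same shape as the loop condition in `train`: stop at the iteration
    budget, or earlier if a token budget is set and has been reached."""
    if iteration >= train_iters:
        return False
    return train_tokens is None or consumed_train_tokens < train_tokens


def count_steps(train_iters, train_tokens, tokens_per_step, start_iteration=0):
    """Hypothetical toy loop: each step consumes a fixed number of tokens."""
    iteration, consumed = start_iteration, 0
    while should_continue(iteration, train_iters, train_tokens, consumed):
        consumed += tokens_per_step  # stands in for the tokens one train_step consumes
        iteration += 1
    return iteration


print(count_steps(10, None, 100))  # iteration budget only: runs all 10 steps
print(count_steps(10, 250, 100))   # token budget hit first: stops after 3 steps
```

Because the check happens before each step, the token budget can be overshot by up to one step's worth of tokens, and a run resumed at `start_iteration` only executes the remaining iterations.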

train(
    forward_step_func,
    valid_forward_step_func,
    model,
    optimizer,
    lr_scheduler,
    train_data_iterator,
    valid_data_iterator,
)


```python
# Excerpt from codegeex/megatron/training.py (lines 882–939); helpers such as
# get_args, get_timers, mpu and train_step are defined or imported elsewhere in the module.
def train(
    forward_step_func,
    valid_forward_step_func,
    model,
    optimizer,
    lr_scheduler,
    train_data_iterator,
    valid_data_iterator,
):
    """Train the model function."""
    args = get_args()
    timers = get_timers()

    # Write args to tensorboard
    write_args_to_tensorboard()

    if args.wandb_logging:
        torch.distributed.barrier()
        print_datetime("before the initialization of wandb")
        timers("wandb-init").start()
        if is_last_rank():
            initialize_wandb_experiment()
        torch.distributed.barrier()
        timers("wandb-init").stop()
        timers.log(["wandb-init"])

    # Turn on training mode which enables dropout.
    for model_module in model:
        model_module.train()

    # Tracking loss.
    total_loss_dict = {}

    # Iterations.
    iteration = args.iteration

    timers("interval-time").start()
    print_datetime("before the start of training step")
    report_memory_flag = True

    while iteration < args.train_iters and (
        args.train_tokens is None or args.consumed_train_tokens < args.train_tokens
    ):
        # print_rank_0(f'=> iteration {iteration}')
        update_num_microbatches(args.consumed_train_samples)
        if args.deepspeed:
            # inform deepspeed of any batch size changes
            global_batch_size = (
                mpu.get_data_parallel_world_size()
                * args.micro_batch_size
                * get_num_microbatches()
            )
            model[0].set_train_batch_size(global_batch_size)

        # print_rank_0(f"==> running train step for iteration {iteration}")
        loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = train_step(
            forward_step_func, train_data_iterator, model, optimizer, lr_scheduler
        )
        # ... (excerpt ends at line 939; the function continues to line 1019)
```
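The W&B block uses a common distributed pattern: every rank waits at a barrier, only one rank (the last, per `is_last_rank()`) performs the one-time initialization, and a second barrier keeps the other ranks from racing ahead. The sketch below simulates this with Python threads standing in for ranks and `threading.Barrier` standing in for `torch.distributed.barrier()`; `WORLD_SIZE`, `init_experiment` and the event log are hypothetical. It assumes the CodeGeeX `is_last_rank()` matches upstream Megatron-LM, where the last rank is `world_size - 1`.

```python
import threading

WORLD_SIZE = 4  # hypothetical number of ranks, simulated here with threads
barrier = threading.Barrier(WORLD_SIZE)  # stand-in for torch.distributed.barrier()
events = []  # (event, rank) pairs in the order they happened
events_lock = threading.Lock()


def is_last_rank(rank: int) -> bool:
    # Assumed to match Megatron-LM: the last rank is world_size - 1.
    return rank == WORLD_SIZE - 1


def init_experiment(rank: int) -> None:
    barrier.wait()  # every rank arrives before initialization begins
    if is_last_rank(rank):
        with events_lock:
            events.append(("init", rank))  # stand-in for initialize_wandb_experiment()
    barrier.wait()  # no rank continues until the last rank has finished
    with events_lock:
        events.append(("continue", rank))


threads = [threading.Thread(target=init_experiment, args=(r,)) for r in range(WORLD_SIZE)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(events[0])  # the single "init" event always precedes every "continue"
```

Without the second barrier, the other ranks could enter the training loop while the last rank is still setting up the experiment. Picking the last rank is plausibly tied to Megatron logging from the final pipeline stage, where losses are available, but that rationale is an inference, not something the excerpt states.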
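The calls `timers("wandb-init").start()`, `.stop()` and `timers.log([...])` follow a named-timer registry: calling `timers(name)` returns the same stopwatch each time, creating it on first use. A minimal sketch of that interface, loosely modeled on Megatron-LM's `Timers` class; the internals and log format here are assumptions, the CodeGeeX version may differ, and the real one synchronizes CUDA before reading the clock, which this CPU-only sketch omits.

```python
import time


class _Timer:
    """One named stopwatch that accumulates elapsed wall-clock time."""

    def __init__(self, name):
        self.name = name
        self.elapsed_ = 0.0
        self.started_ = False
        self.start_time = 0.0

    def start(self):
        if self.started_:
            raise RuntimeError(f"timer {self.name} already started")
        self.start_time = time.perf_counter()
        self.started_ = True

    def stop(self):
        if not self.started_:
            raise RuntimeError(f"timer {self.name} is not started")
        self.elapsed_ += time.perf_counter() - self.start_time
        self.started_ = False

    def elapsed(self, reset=True):
        """Return accumulated seconds, optionally zeroing the total."""
        value = self.elapsed_
        if reset:
            self.elapsed_ = 0.0
        return value


class Timers:
    """Registry: timers(name) creates a timer on first use, then reuses it."""

    def __init__(self):
        self.timers = {}

    def __call__(self, name):
        if name not in self.timers:
            self.timers[name] = _Timer(name)
        return self.timers[name]

    def log(self, names, reset=True):
        """Print (and return) elapsed milliseconds for the named timers."""
        parts = [f"{n}: {self.timers[n].elapsed(reset=reset) * 1000.0:.2f}" for n in names]
        line = "time (ms) | " + " | ".join(parts)
        print(line)
        return line


timers = Timers()
timers("wandb-init").start()
time.sleep(0.01)  # stands in for the work being timed
timers("wandb-init").stop()
timers.log(["wandb-init"])
```

Note that `timers("interval-time")` in the excerpt is started before the loop and, unlike `wandb-init`, is not stopped in the lines shown.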

Callers (1):
- pretrain (function) · 0.85

Calls (15; 9 listed):
- get_args (function) · 0.90
- get_timers (function) · 0.90
- is_last_rank (function) · 0.90
- update_num_microbatches (function) · 0.90
- get_num_microbatches (function) · 0.90
- calc_params_l2_norm (function) · 0.90
- print_datetime (function) · 0.85
- train_step (function) · 0.85
- training_log (function) · 0.85
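Each loop iteration first calls `update_num_microbatches(args.consumed_train_samples)`, and under DeepSpeed it recomputes the global batch size as data-parallel world size × micro-batch size × number of micro-batches. The sketch below shows that arithmetic with a linear batch-size ramp-up modeled on upstream Megatron-LM's ramp-up calculator. Whether CodeGeeX's `update_num_microbatches` follows this schedule is an assumption, and the helper and parameter names (`start`, `increment`, `rampup_samples`, `target`) are hypothetical.

```python
def current_global_batch_size(
    consumed_samples: int, start: int, increment: int, rampup_samples: int, target: int
) -> int:
    """Assumed linear ramp-up: grow from `start` to `target` in steps of
    `increment`, spread evenly over the first `rampup_samples` samples."""
    if consumed_samples > rampup_samples:
        return target
    num_increments = (target - start) // increment
    if num_increments <= 0:
        return target  # nothing to ramp
    samples_per_increment = rampup_samples / num_increments
    steps = int(consumed_samples / samples_per_increment)
    return start + steps * increment


def num_microbatches(global_batch_size: int, micro_batch_size: int, dp_world_size: int) -> int:
    """Micro-batches each data-parallel rank runs per iteration."""
    per_pass = micro_batch_size * dp_world_size
    if global_batch_size % per_pass != 0:
        raise ValueError("global batch size must be divisible by micro_batch_size * dp_world_size")
    return global_batch_size // per_pass


# Hypothetical setup: 8 data-parallel ranks, micro-batch size 4,
# ramping from 32 to 256 samples per batch over the first 1000 samples.
for consumed in (0, 500, 2000):
    gbs = current_global_batch_size(consumed, start=32, increment=32, rampup_samples=1000, target=256)
    n_micro = num_microbatches(gbs, micro_batch_size=4, dp_world_size=8)
    # Same product the loop hands to model[0].set_train_batch_size(...)
    assert 8 * 4 * n_micro == gbs
    print(consumed, gbs, n_micro)
```

In this sketch every ramp step must stay a multiple of `micro_batch_size * dp_world_size`; otherwise the micro-batch count is not an integer and `num_microbatches` raises.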

Tested by

no test coverage detected