Train the model.
(
forward_step_func,
valid_forward_step_func,
model,
optimizer,
lr_scheduler,
train_data_iterator,
valid_data_iterator,
)
| 880 | |
| 881 | |
| 882 | def train( |
| 883 | forward_step_func, |
| 884 | valid_forward_step_func, |
| 885 | model, |
| 886 | optimizer, |
| 887 | lr_scheduler, |
| 888 | train_data_iterator, |
| 889 | valid_data_iterator, |
| 890 | ): |
| 891 | """Train the model function.""" |
| 892 | args = get_args() |
| 893 | timers = get_timers() |
| 894 | |
| 895 | # Write args to tensorboard |
| 896 | write_args_to_tensorboard() |
| 897 | |
| 898 | if args.wandb_logging: |
| 899 | torch.distributed.barrier() |
| 900 | print_datetime("before the initialization of wandb") |
| 901 | timers("wandb-init").start() |
| 902 | if is_last_rank(): |
| 903 | initialize_wandb_experiment() |
| 904 | torch.distributed.barrier() |
| 905 | timers("wandb-init").stop() |
| 906 | timers.log(["wandb-init"]) |
| 907 | |
| 908 | # Turn on training mode which enables dropout. |
| 909 | for model_module in model: |
| 910 | model_module.train() |
| 911 | |
| 912 | # Tracking loss. |
| 913 | total_loss_dict = {} |
| 914 | |
| 915 | # Iterations. |
| 916 | iteration = args.iteration |
| 917 | |
| 918 | timers("interval-time").start() |
| 919 | print_datetime("before the start of training step") |
| 920 | report_memory_flag = True |
| 921 | |
| 922 | while iteration < args.train_iters and ( |
| 923 | args.train_tokens is None or args.consumed_train_tokens < args.train_tokens |
| 924 | ): |
| 925 | # print_rank_0(f'=> iteration {iteration}') |
| 926 | update_num_microbatches(args.consumed_train_samples) |
| 927 | if args.deepspeed: |
| 928 | # inform deepspeed of any batch size changes |
| 929 | global_batch_size = ( |
| 930 | mpu.get_data_parallel_world_size() |
| 931 | * args.micro_batch_size |
| 932 | * get_num_microbatches() |
| 933 | ) |
| 934 | model[0].set_train_batch_size(global_batch_size) |
| 935 | |
| 936 | # print_rank_0(f"==> running train step for iteration {iteration}") |
| 937 | loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = train_step( |
| 938 | forward_step_func, train_data_iterator, model, optimizer, lr_scheduler |
| 939 | ) |
No test coverage detected.