(args, index, model, optimizer, finetune=False)
| 175 | return dataset_picker, dataloaders, sum(datalengths) |
| 176 | |
def train(args, index, model, optimizer, finetune=False):
    """Run one training pass over the datasets scheduled for epoch ``index``.

    The dataset schedule and per-dataset loaders come from
    ``get_train_dataset``. For every scheduled entry a batch is pulled,
    pushed through ``model.network`` (forward, backward, step), and on
    gradient-accumulation boundaries step metrics and LAMB coefficients are
    reported and the step counters advance. The loop stops early when
    ``is_time_to_exit`` says so. Afterwards ``pretrain_validation`` runs if
    this is not a fine-tuning pass and ``args.max_seq_length == 512``.

    Args:
        args: Run settings. Reads ``config``, ``logger``, ``device``,
            ``fp16``, ``train_micro_batch_size_per_gpu`` and
            ``max_seq_length``.
        index: Epoch index; shown as ``index + 1`` in the early-exit message.
        model: Object providing ``train()`` and a ``network`` engine with
            ``backward``, ``step`` and ``is_gradient_accumulation_boundary``.
        optimizer: Handed to ``update_learning_rate`` and
            ``report_lamb_coefficients``; on the non-fp16 path its
            ``param_groups`` are read for the reported learning rate.
        finetune: Forwarded to ``get_train_dataset``; when true the
            validation run at the end is skipped.

    Side effects:
        Updates the module globals ``global_step`` and
        ``global_data_samples``.
    """
    global global_step
    global global_data_samples
    global last_global_step_from_restore

    dataset_picker, dataloaders, total_length = get_train_dataset(args, index, finetune)
    current_data_sample_count = global_data_samples
    global_data_samples += total_length
    config = args.config
    logger = args.logger
    print('total_length', total_length, 'global_data_samples', global_data_samples)

    model.train()

    epoch_step = 0
    for dataset_type in tqdm(dataset_picker, smoothing=1):
        # Only the fetch is guarded: an exhausted loader means "skip this
        # entry", while a StopIteration from anywhere else is a real error
        # and must not be swallowed.
        try:
            batch = next(dataloaders[dataset_type])
        except StopIteration:
            continue

        # Move every tensor of the batch onto the configured device.
        batch = tuple(t.to(args.device) for t in batch)

        # Forward pass
        loss = model.network(batch)
        unscaled_loss = loss.item()
        current_data_sample_count += (args.train_micro_batch_size_per_gpu * dist.get_world_size())

        model.network.backward(loss)

        if model.network.is_gradient_accumulation_boundary():
            if args.fp16:
                # Modify learning rate with the special warm-up BERT uses.
                lr_this_step = update_learning_rate(config, global_step, optimizer)
            else:
                # Without fp16 the rate is not updated here (the original
                # comment says BertAdam handles it), so nothing used to bind
                # lr_this_step and the report below raised
                # UnboundLocalError. Report the optimizer's stored rate.
                # NOTE(review): if the optimizer applies its own schedule
                # internally, this may differ from the effective rate.
                lr_this_step = optimizer.param_groups[0]['lr']

            report_step_metrics(args, lr_this_step, unscaled_loss, global_step, current_data_sample_count)

            model.network.step()

            report_lamb_coefficients(args, optimizer)
            global_step += 1
            epoch_step += 1
        else:
            # Call DeepSpeed engine step on micro steps
            model.network.step()

        # Steps taken since the last restore, as opposed to the absolute count.
        current_global_step = global_step - last_global_step_from_restore
        if is_time_to_exit(args=args,
                           epoch_steps=epoch_step,
                           global_steps=current_global_step):
            print(f'Warning: Early epoch termination due to max steps limit, epoch step ={epoch_step}, global step = {current_global_step}, epoch = {index+1}')
            break

    # Run Validation Loss
    if not finetune and args.max_seq_length == 512:
        logger.info(f"TRAIN BATCH SIZE: {args.train_micro_batch_size_per_gpu}")
        pretrain_validation(args, index, model)
no test coverage detected