Backward step.
(optimizer, model, lm_loss, args, timers)
| 292 | |
| 293 | |
def backward_step(optimizer, model, lm_loss, args, timers):
    """Run the backward pass for one step and return the reduced loss.

    Two paths are supported, selected by ``args.deepspeed``:

    * DeepSpeed: ``model.backward`` is called with the loss, the
      ``'allreduce'`` timer is reset, and ``lm_loss.view(1)`` is returned
      without any further reduction or clipping.
    * Otherwise: gradients are zeroed, the loss is back-propagated (through
      the optimizer wrapper when ``args.fp16`` is set), the loss is
      all-reduced and divided by ``args.world_size``, parameters are
      all-reduced manually when ``USE_TORCH_DDP`` is false, master gradients
      are updated for fp16, and gradients are clipped when
      ``args.clip_grad > 0``.

    Args:
        optimizer: Optimizer; only used on the non-DeepSpeed path.
        model: Model (or DeepSpeed engine when ``args.deepspeed`` is set).
        lm_loss: Loss to back-propagate.
        args: Run configuration (``deepspeed``, ``fp16``, ``world_size``,
            ``fp32_allreduce``, ``clip_grad``).
        timers: Callable returning a named timer object.

    Returns:
        ``lm_loss.view(1)``; on the non-DeepSpeed path its data has been
        all-reduced and divided by ``args.world_size``.
    """
    deepspeed_enabled = args.deepspeed

    # Backward pass.
    if deepspeed_enabled:
        model.backward(lm_loss)
    else:
        optimizer.zero_grad()
        if args.fp16:
            # Master grads are refreshed later, after the parameter allreduce.
            optimizer.backward(lm_loss, update_master_grads=False)
        else:
            lm_loss.backward()

    # One-element view of the loss used for cross-process reduction.
    # NOTE(review): presumably lm_loss is a torch tensor, in which case this
    # view shares storage with it and the in-place all_reduce below also
    # changes lm_loss - confirm callers do not reuse lm_loss afterwards.
    loss_for_logging = lm_loss.view(1)

    if deepspeed_enabled:
        # DeepSpeed's backward already handled the allreduce communication.
        # Reset the timer so the timer logs downstream do not break.
        timers('allreduce').reset()
        return loss_for_logging

    # Reduce the loss across processes.
    torch.distributed.all_reduce(loss_for_logging.data)
    loss_for_logging.data = loss_for_logging.data / args.world_size

    if not USE_TORCH_DDP:
        # Manual gradient allreduce when torch DDP is not in use.
        timers('allreduce').start()
        model.allreduce_params(reduce_after=False,
                               fp32_allreduce=args.fp32_allreduce)
        timers('allreduce').stop()

    # Update master gradients, then clip to help prevent exploding gradients.
    if args.fp16:
        optimizer.update_master_grads()
        if args.clip_grad > 0:
            optimizer.clip_master_grads(args.clip_grad)
    elif args.clip_grad > 0:
        mpu.clip_grad_norm(model.parameters(), args.clip_grad)

    return loss_for_logging
| 343 | |
| 344 | def see_memory_usage(message, force=False): |
| 345 | if not force: |
no test coverage detected