Backward step through passed-in output tensor. If last stage, output_tensor_grad is None, otherwise gradient of loss with respect to stage's output tensor. Returns gradient of loss with respect to input tensor (None if first stage).
(
optimizer, input_tensor, output_tensor, output_tensor_grad, model=None
)
| 70 | |
| 71 | |
def backward_step(
    optimizer, input_tensor, output_tensor, output_tensor_grad, model=None
):
    """Backward step through passed-in output tensor.

    If last stage, output_tensor_grad is None, otherwise gradient of loss
    with respect to stage's output tensor.

    Args:
        optimizer: Used only on the non-DeepSpeed path; when
            ``output_tensor_grad`` is None, ``optimizer.scale_loss`` is applied
            to ``output_tensor`` before backpropagating.
        input_tensor: This stage's input tensor, or None (first stage). Its
            gradient is retained so it can be returned.
        output_tensor: Tensor to backpropagate from.
        output_tensor_grad: Gradient w.r.t. ``output_tensor``, or None on the
            last stage.
        model: Required when ``args.deepspeed`` is set; its ``backward``
            method performs the backward pass on that path.

    Returns:
        Gradient of loss with respect to input tensor (None if first stage).

    Raises:
        ValueError: If ``args.deepspeed`` is set and ``model`` is None.
    """
    args = get_args()

    # Validate explicitly instead of with ``assert``: asserts are stripped
    # under ``python -O`` and the failure would then surface later as an
    # AttributeError on ``None.backward``.
    if args.deepspeed and model is None:
        raise ValueError("model must be provided when args.deepspeed is set")

    timers = get_timers()
    timers("backward-compute").start()
    try:
        # Retain the grad on the input_tensor so it is still available after
        # backward and can be returned to the caller.
        if input_tensor is not None:
            input_tensor.retain_grad()

        if args.deepspeed:
            model.backward(output_tensor)
        else:
            # No upstream gradient (last stage): apply the optimizer's loss
            # scaling before running backward.
            if output_tensor_grad is None:
                output_tensor = optimizer.scale_loss(output_tensor)
            torch.autograd.backward(
                output_tensor, grad_tensors=output_tensor_grad
            )

        # Collect the grad of the input_tensor.
        input_tensor_grad = None
        if input_tensor is not None:
            input_tensor_grad = input_tensor.grad
    finally:
        # Always stop the timer so a failed backward pass does not leave
        # "backward-compute" running.
        timers("backward-compute").stop()

    return input_tensor_grad
| 110 | |
| 111 | |
| 112 | @contextmanager |
no test coverage detected