r"""Execute backward pass on the loss. Arguments: loss: Torch tensor on which to execute backward propagation. retain_graph: bool, default: False; forwarded to the backward call as the user's choice of retain_graph. scale_wrt_gas: bool, default: True; whether to scale gradients and return value by gradient accumulation steps."""
(self, loss, retain_graph=False, scale_wrt_gas=True)
| 2850 | |
@instrument_w_nvtx
def backward(self, loss, retain_graph=False, scale_wrt_gas=True):
    r"""Execute backward pass on the loss

    Arguments:
        loss: Torch tensor on which to execute backward propagation
        retain_graph: bool, default: False
            forwarded to the underlying ``backward()`` call; overridden to True
            (together with ``create_graph=True``) when ``self.eigenvalue_enabled()``
        scale_wrt_gas: bool, default: True
            whether to scale gradients and return value by gradient accumulation steps
            (the value is stored on ``self._scale_wrt_gas``; only the return-value
            scaling happens in this method)

    Returns:
        ``loss / self.gradient_accumulation_steps()`` when ``scale_wrt_gas`` is True,
        otherwise ``loss`` unchanged. This value is computed before any loss scaling
        applied for the backward call itself.

    Raises:
        AssertionError: if ``self.optimizer`` is None or a ``DummyOptim``, or if
            ``maybe_loss_for_backward(loss)`` is falsy.
    """
    assert self.optimizer is not None and not isinstance(self.optimizer, DummyOptim), \
        "must provide optimizer during init in order to use backward"
    assert maybe_loss_for_backward(
        loss), "loss must be a scalar tensor. If you need to pass output gradients, backward() of output tensors"

    # Mark that the engine itself is driving this backward pass. Per the original
    # comment this flag keeps hooks from firing; the epilogue is called manually below.
    # NOTE(review): the hook side is not visible from this method - confirm.
    self._running_engine_backward = True
    try:
        # Store scale_wrt_gas so the hook can respect it
        self._scale_wrt_gas = scale_wrt_gas

        backward_kwargs = {"retain_graph": retain_graph}
        if self.eigenvalue_enabled():
            backward_kwargs["create_graph"] = True
            backward_kwargs["retain_graph"] = True

        # Used only for return value
        gas_scaled_loss = loss / self.gradient_accumulation_steps() if scale_wrt_gas else loss

        # TODO: handle these scaling with direct calls to loss.backward()
        if isinstance(self.optimizer, ZeROOptimizer):
            loss = self.optimizer.scale_if_loss(loss)
        elif self.torch_autocast_z0_gradscaler:
            loss = self.torch_autocast_z0_gradscaler.scale(loss)

        with compiled_autograd(self._is_compiled_autograd_enabled, self._compile_kwargs):
            if self.zero_optimization() or not self.amp_enabled():
                loss.backward(**backward_kwargs)
            else:
                # Reaching here implies amp is enabled and ZeRO is not.
                # AMP requires delaying unscale when inside gradient accumulation boundaries
                # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
                delay_unscale = not self.is_gradient_accumulation_boundary()
                with amp.scale_loss(loss, self.optimizer, delay_unscale=delay_unscale) as scaled_loss:
                    scaled_loss.backward(**backward_kwargs)

        # backward_epilogue is not called in a hook when self._support_torch_style_backward is False
        self._backward_epilogue()
    finally:
        # Always clear the flag, even if backward or the epilogue raised, so a
        # failed call does not leave the engine marked as "inside backward".
        self._running_engine_backward = False

    return gas_scaled_loss
| 2901 | |
| 2902 | def is_gradient_accumulation_boundary(self): |
| 2903 | """ |
nothing calls this directly
no test coverage detected