def activation_checkpoint(function, activation_offload, *args, use_reentrant: bool = True):
    """Checkpoint the computation while preserving the RNG states, adapted from PyTorch's ``torch.utils.checkpoint``.

    Args:
        function: The forward-pass function to checkpoint. It must know how to handle the input tuple.
        activation_offload: Whether to offload activations to the CPU.
        args: The positional arguments passed to ``function``.
        use_reentrant (bool): Whether to use the reentrant implementation (``CheckpointFunction``).
            If ``False``, the non-reentrant implementation is used, which may give users more
            flexibility in defining their checkpointed function.

    Returns:
        The output of running ``function`` with the provided ``args``.
    """
    if use_reentrant:
        return CheckpointFunction.apply(function, activation_offload, *args)
    else:
        return _checkpoint_without_reentrant(
            function,
            activation_offload,
            *args,
        )

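The docstring's key promise is that recomputation during the backward pass replays the same random draws (for example, dropout masks) as the original forward pass. The toy sketch below illustrates that idea in pure Python, assuming the standard `random` module as a stand-in for PyTorch's RNGs. `ToyCheckpoint` and `noisy_layer` are hypothetical names introduced for illustration; this is not ColossalAI's `CheckpointFunction` and it does no autograd or activation offloading.

```python
import random
from typing import Any, Callable, Optional, Tuple


def noisy_layer(x: float) -> float:
    # Hypothetical stand-in for a dropout-like layer: the output depends on a random draw.
    keep = random.random() >= 0.5
    return x * 2.0 if keep else 0.0


class ToyCheckpoint:
    """Hypothetical sketch: keep only the inputs and the RNG state, recompute on demand."""

    def __init__(self, fn: Callable[..., Any]) -> None:
        self.fn = fn
        self.saved_args: Tuple[Any, ...] = ()
        self.saved_rng_state: Optional[Any] = None

    def forward(self, *args: Any) -> Any:
        # Snapshot the RNG state *before* running fn; intermediate values are not stored.
        self.saved_rng_state = random.getstate()
        self.saved_args = args
        return self.fn(*args)

    def recompute(self) -> Any:
        # Replay fn under the saved RNG state so it sees the same random draws,
        # then restore the caller's RNG state so code outside is unaffected.
        if self.saved_rng_state is None:
            raise RuntimeError("forward() must be called before recompute()")
        caller_state = random.getstate()
        random.setstate(self.saved_rng_state)
        try:
            return self.fn(*self.saved_args)
        finally:
            random.setstate(caller_state)


if __name__ == "__main__":
    random.seed(0)
    block = ToyCheckpoint(lambda x: [noisy_layer(x) for _ in range(8)])
    first = block.forward(3.0)
    random.random()  # unrelated code advances the global RNG in between
    assert block.recompute() == first  # the replay reproduces the same "dropout" pattern
```

Restoring the caller's state in `finally` keeps the replay invisible to surrounding code; PyTorch's `torch.utils.checkpoint` applies the same save-and-restore idea to the CPU and CUDA RNG states.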

def _checkpoint_without_reentrant(function, activation_offload=False, *args):  # pylint: disable=W1113