github.com/InternLM/InternLM

Function activation_checkpoint

internlm/utils/checkpoint.py:152–170

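Checkpoints a computation while preserving the RNG states. Modified from PyTorch's torch.utils.checkpoint. It takes the forward pass function, a flag that controls whether activations are offloaded to CPU, the positional arguments for the function, and use_reentrant, which selects between the reentrant path (CheckpointFunction.apply) and the non-reentrant path (_checkpoint_without_reentrant). It returns the output of running the function with the provided arguments.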

(function, activation_offload, *args, use_reentrant: bool = True)


def activation_checkpoint(function, activation_offload, *args, use_reentrant: bool = True):
    """Checkpoint the computation while preserving the RNG states, modified from PyTorch's torch.utils.checkpoint.

    Args:
        function: The forward pass function. It should know how to handle the input tuples.
        activation_offload: Whether to offload activations to CPU.
        args (tuple): Positional arguments passed to the function.
        use_reentrant (bool): Whether to use the reentrant implementation. With use_reentrant=False,
            users may have more flexibility in defining their checkpoint function.

    Returns:
        Output of running function with provided args.
    """
    if use_reentrant:
        return CheckpointFunction.apply(function, activation_offload, *args)
    else:
        return _checkpoint_without_reentrant(
            function,
            activation_offload,
            *args,
        )


def _checkpoint_without_reentrant(function, activation_offload=False, *args):  # pylint: disable=W1113
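
The docstring's phrase "preserve the rng states" can be illustrated with a small, library-free sketch. This is not InternLM's implementation: it uses Python's standard `random` module in place of the PyTorch RNG, and the names `checkpointed_forward`, `recompute`, and `noisy_scale` are hypothetical. The idea it shows is only the general one: snapshot the RNG state before the forward pass, keep the inputs instead of the intermediate results, and restore that snapshot when the computation is replayed so the replay draws the same random numbers.

```python
import random


def checkpointed_forward(function, *args):
    """Run function(*args) and keep only what is needed to replay it later.

    Hypothetical helper: returns the output plus a (function, args, rng_state) record.
    """
    # Snapshot the RNG state before the function consumes any random numbers.
    rng_state = random.getstate()
    output = function(*args)
    return output, (function, args, rng_state)


def recompute(saved):
    """Replay the saved computation under the RNG state captured at forward time."""
    function, args, rng_state = saved
    # Remember the caller's current RNG state so the replay does not disturb it.
    current_state = random.getstate()
    random.setstate(rng_state)
    try:
        return function(*args)
    finally:
        random.setstate(current_state)


def noisy_scale(values, drop_prob):
    """Dropout-like function whose output depends on the RNG."""
    return [0.0 if random.random() < drop_prob else v for v in values]


random.seed(0)
first, saved = checkpointed_forward(noisy_scale, [1.0, 2.0, 3.0, 4.0], 0.5)
random.random()  # unrelated RNG use between the forward pass and the replay
state_before = random.getstate()
replayed = recompute(saved)
state_after = random.getstate()
```

Without restoring the snapshot, the replay would draw different random numbers and `replayed` could differ from `first`; restoring the caller's state afterwards keeps the replay from shifting any later random draws.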

Callers 2

forward · Method · 0.90
_forward · Method · 0.90

Calls 1

Tested by

no test coverage detected