hub / github.com/InternLM/InternLM / _checkpoint_without_reentrant

Function _checkpoint_without_reentrant

internlm/utils/checkpoint.py:173–269
(function, activation_offload=False, *args)

Source from the content-addressed store, hash-verified

def _checkpoint_without_reentrant(function, activation_offload=False, *args):  # pylint: disable=W1113
    # store rng_state
    fwd_cpu_state = torch.get_rng_state()
    sync_states()
    fwd_seed_states = get_states(copy=True)
    fwd_current_mode = get_current_mode()

    # check if use autocast
    if hasattr(torch, "is_autocast_enabled"):
        has_autocast_in_fwd = torch.is_autocast_enabled()
    else:
        has_autocast_in_fwd = False

    # using WeakKeyDictionary to store all the activation the first time we call unpack
    storage: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
    weak_holder_list = []

    # class for weakref.ref
    class Holder:
        pass

    # return a Holder object for later unpack process
    def pack():
        res = Holder()
        weak_holder_list.append(weakref.ref(res))
        return res

    # unpack hook
    def unpack(x):
        unpack_counter = 0

        # re-compute all the activation inside the function when we first call unpack
        if len(storage) == 0:

            def inner_pack(inner):
                nonlocal unpack_counter
                unpack_counter += 1

                # If the holder went out of scope, the SavedVariable is dead and so
                # the value will never be read from the storage. Skip filling it.
                if weak_holder_list[unpack_counter - 1]() is None:
                    return

                # Use detach here to ensure we don't keep the temporary autograd
                # graph created during the second forward
                storage[weak_holder_list[unpack_counter - 1]()] = inner.detach()
                return

            def inner_unpack(packed):
                raise RuntimeError("You are calling backwards on a tensor that is never exposed. Please open an issue.")

            # restore rng state
            torch.set_rng_state(fwd_cpu_state)
            for parallel_mode, state in fwd_seed_states.items():
                set_seed_states(parallel_mode, state)
            set_mode(fwd_current_mode)

            # reload arg into device if needed
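
The listing above stops before the recomputation itself, so the following is only a torch-free sketch of the bookkeeping it sets up: `pack` hands out empty `Holder` placeholders, weak references record the pack order, and the first `unpack` fills a `WeakKeyDictionary` by position while skipping holders that have already been collected. The names `make_hooks`, `recompute`, and `stats` are hypothetical and do not appear in the original function; plain integers stand in for activation tensors.

```python
import gc
import weakref


class Holder:
    """Empty placeholder object; only its identity matters."""


def make_hooks(recompute):
    """Build pack/unpack callables around a hypothetical `recompute` function.

    `recompute` must return the values in the same order they were packed.
    """
    storage = weakref.WeakKeyDictionary()  # Holder -> recomputed value
    weak_holders = []                      # weak refs, kept in pack order
    stats = {"recomputes": 0}

    def pack(_value):
        # Drop the real value and hand back a placeholder instead.
        holder = Holder()
        weak_holders.append(weakref.ref(holder))
        return holder

    def unpack(holder):
        if len(storage) == 0:
            # First unpack: recompute everything once and fill the storage.
            stats["recomputes"] += 1
            for ref, value in zip(weak_holders, recompute()):
                target = ref()
                if target is None:
                    continue  # holder already collected; nobody will read it
                storage[target] = value
        return storage[holder]

    return pack, unpack, storage, stats


values = [10, 20, 30]
pack, unpack, storage, stats = make_hooks(lambda: list(values))
h0, h1, h2 = (pack(v) for v in values)

# Simulate a saved value that is released before the backward pass.
del h1
gc.collect()

first = unpack(h2)   # triggers the single recomputation
second = unpack(h0)  # served from storage, no second recomputation
```

Because the dictionary holds its keys weakly, an entry disappears as soon as its holder is released, so recomputed values do not outlive the consumers that need them.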

Callers 1

activation_checkpoint · Function · 0.85

Calls 4

sync_states · Function · 0.90
get_states · Function · 0.90
get_current_mode · Function · 0.90
get_current_device · Function · 0.85
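
In the source, `sync_states`, `get_states`, and `get_current_mode` appear under the "store rng_state" comment, and the saved values are written back before recomputation. The idea can be illustrated with a small sketch that uses Python's standard `random` module as a stand-in for `torch.get_rng_state` / `torch.set_rng_state`; the helpers `noisy_forward` and `run_with_replay` are hypothetical and not part of InternLM.

```python
import random


def noisy_forward(xs, drop_prob=0.5):
    """Toy stand-in for a layer with dropout: zero each element at random."""
    return [0.0 if random.random() < drop_prob else x for x in xs]


def run_with_replay(xs):
    # Snapshot the generator state before the first forward pass.
    saved_state = random.getstate()
    first = noisy_forward(xs)

    # Unrelated work advances the generator in between.
    random.random()

    # Restore the snapshot so the recomputation draws the same numbers.
    random.setstate(saved_state)
    second = noisy_forward(xs)
    return first, second


first, second = run_with_replay([1.0, 2.0, 3.0, 4.0])
```

Without the restore step, the second pass would generally drop different elements, and the recomputed activations would no longer match the ones the original forward pass produced.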

Tested by

no test coverage detected