(function, activation_offload=False, *args)
| 171 | |
| 172 | |
| 173 | def _checkpoint_without_reentrant(function, activation_offload=False, *args): # pylint: disable=W1113 |
| 174 | # store rng_state |
| 175 | fwd_cpu_state = torch.get_rng_state() |
| 176 | sync_states() |
| 177 | fwd_seed_states = get_states(copy=True) |
| 178 | fwd_current_mode = get_current_mode() |
| 179 | |
| 180 | # check if use autocast |
| 181 | if hasattr(torch, "is_autocast_enabled"): |
| 182 | has_autocast_in_fwd = torch.is_autocast_enabled() |
| 183 | else: |
| 184 | has_autocast_in_fwd = False |
| 185 | |
| 186 | # using WeakKeyDictionary to store all the activation the first time we call unpack |
| 187 | storage: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() |
| 188 | weak_holder_list = [] |
| 189 | |
| 190 | # class for weakref.ref |
| 191 | class Holder: |
| 192 | pass |
| 193 | |
| 194 | # return a Holder object for later unpack process |
| 195 | def pack(): |
| 196 | res = Holder() |
| 197 | weak_holder_list.append(weakref.ref(res)) |
| 198 | return res |
| 199 | |
| 200 | # unpack hook |
| 201 | def unpack(x): |
| 202 | unpack_counter = 0 |
| 203 | |
| 204 | # re-compute all the activation inside the function when we first call unpack |
| 205 | if len(storage) == 0: |
| 206 | |
| 207 | def inner_pack(inner): |
| 208 | nonlocal unpack_counter |
| 209 | unpack_counter += 1 |
| 210 | |
| 211 | # If the holder went out of scope, the SavedVariable is dead and so |
| 212 | # the value will never be read from the storage. Skip filling it. |
| 213 | if weak_holder_list[unpack_counter - 1]() is None: |
| 214 | return |
| 215 | |
| 216 | # Use detach here to ensure we don't keep the temporary autograd |
| 217 | # graph created during the second forward |
| 218 | storage[weak_holder_list[unpack_counter - 1]()] = inner.detach() |
| 219 | return |
| 220 | |
| 221 | def inner_unpack(packed): |
| 222 | raise RuntimeError("You are calling backwards on a tensor that is never exposed. Please open an issue.") |
| 223 | |
| 224 | # restore rng state |
| 225 | torch.set_rng_state(fwd_cpu_state) |
| 226 | for parallel_mode, state in fwd_seed_states.items(): |
| 227 | set_seed_states(parallel_mode, state) |
| 228 | set_mode(fwd_current_mode) |
| 229 | |
| 230 | # reload arg into device if needed |
no test coverage detected