MCPcopy Index your code
hub / github.com/InternLM/InternLM / backward

Method backward

internlm/utils/checkpoint.py:90–149  ·  view source on GitHub ↗
(ctx, *args)

Source from the content-addressed store, hash-verified

88
89 @staticmethod
90 def backward(ctx, *args):
91 if not torch.autograd._is_checkpoint_valid():
92 raise RuntimeError(
93 "Checkpointing is not compatible with .grad() or when an `inputs` parameter is "
94 "passed to .backward(). Please use .backward() and do not pass its `inputs` argument."
95 )
96 # Copy the list to avoid modifying original list.
97 inputs = list(ctx.inputs)
98 tensor_indices = ctx.tensor_indices
99
100 if ctx.activation_offload:
101 tensors = ctx.tensor_inputs
102 else:
103 tensors = ctx.saved_tensors
104
105 # store the current states
106 bwd_cpu_rng_state = torch.get_rng_state()
107 sync_states()
108 bwd_seed_states = get_states(copy=True)
109 bwd_current_mode = get_current_mode()
110
111 # set the states to what it used to be
112 torch.set_rng_state(ctx.fwd_cpu_rng_state)
113 for parallel_mode, state in ctx.fwd_seed_states.items():
114 set_seed_states(parallel_mode, state)
115 set_mode(ctx.fwd_current_mode)
116 if ctx.activation_offload:
117 tensors = copy_to_device(tensors, ctx.device)
118
119 # Fill in inputs with appropriate saved tensors.
120 for i, idx in enumerate(tensor_indices):
121 inputs[idx] = tensors[i]
122 detached_inputs = detach_variable(tuple(inputs))
123 if ctx.had_autocast_in_fwd:
124 with torch.enable_grad(), torch.cuda.amp.autocast():
125 outputs = ctx.run_function(*detached_inputs)
126 else:
127 with torch.enable_grad():
128 outputs = ctx.run_function(*detached_inputs)
129
130 if isinstance(outputs, torch.Tensor):
131 outputs = (outputs,)
132 # recover the rng states
133 torch.set_rng_state(bwd_cpu_rng_state)
134 for parallel_mode, state in bwd_seed_states.items():
135 set_seed_states(parallel_mode, state)
136 set_mode(bwd_current_mode)
137
138 # run backward() with only tensor that requires grad
139 outputs_with_grad = []
140 args_with_grad = []
141 for i in range(len(outputs)):
142 if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
143 outputs_with_grad.append(outputs[i])
144 args_with_grad.append(args[i])
145 if len(outputs_with_grad) == 0:
146 raise RuntimeError("none of output has requires_grad=True," " this checkpoint() is not necessary")
147 torch.autograd.backward(outputs_with_grad, args_with_grad)

Callers 1

Calls 6

sync_states · Function · 0.90
get_states · Function · 0.90
get_current_mode · Function · 0.90
set_seed_states · Function · 0.90
set_mode · Function · 0.90
copy_to_device · Function · 0.85

Tested by

no test coverage detected