(ctx, *args)
| 88 | |
| 89 | @staticmethod |
| 90 | def backward(ctx, *args): |
| 91 | if not torch.autograd._is_checkpoint_valid(): |
| 92 | raise RuntimeError( |
| 93 | "Checkpointing is not compatible with .grad() or when an `inputs` parameter is " |
| 94 | "passed to .backward(). Please use .backward() and do not pass its `inputs` argument." |
| 95 | ) |
| 96 | # Copy the list to avoid modifying original list. |
| 97 | inputs = list(ctx.inputs) |
| 98 | tensor_indices = ctx.tensor_indices |
| 99 | |
| 100 | if ctx.activation_offload: |
| 101 | tensors = ctx.tensor_inputs |
| 102 | else: |
| 103 | tensors = ctx.saved_tensors |
| 104 | |
| 105 | # store the current states |
| 106 | bwd_cpu_rng_state = torch.get_rng_state() |
| 107 | sync_states() |
| 108 | bwd_seed_states = get_states(copy=True) |
| 109 | bwd_current_mode = get_current_mode() |
| 110 | |
| 111 | # set the states to what it used to be |
| 112 | torch.set_rng_state(ctx.fwd_cpu_rng_state) |
| 113 | for parallel_mode, state in ctx.fwd_seed_states.items(): |
| 114 | set_seed_states(parallel_mode, state) |
| 115 | set_mode(ctx.fwd_current_mode) |
| 116 | if ctx.activation_offload: |
| 117 | tensors = copy_to_device(tensors, ctx.device) |
| 118 | |
| 119 | # Fill in inputs with appropriate saved tensors. |
| 120 | for i, idx in enumerate(tensor_indices): |
| 121 | inputs[idx] = tensors[i] |
| 122 | detached_inputs = detach_variable(tuple(inputs)) |
| 123 | if ctx.had_autocast_in_fwd: |
| 124 | with torch.enable_grad(), torch.cuda.amp.autocast(): |
| 125 | outputs = ctx.run_function(*detached_inputs) |
| 126 | else: |
| 127 | with torch.enable_grad(): |
| 128 | outputs = ctx.run_function(*detached_inputs) |
| 129 | |
| 130 | if isinstance(outputs, torch.Tensor): |
| 131 | outputs = (outputs,) |
| 132 | # recover the rng states |
| 133 | torch.set_rng_state(bwd_cpu_rng_state) |
| 134 | for parallel_mode, state in bwd_seed_states.items(): |
| 135 | set_seed_states(parallel_mode, state) |
| 136 | set_mode(bwd_current_mode) |
| 137 | |
| 138 | # run backward() with only tensor that requires grad |
| 139 | outputs_with_grad = [] |
| 140 | args_with_grad = [] |
| 141 | for i in range(len(outputs)): |
| 142 | if torch.is_tensor(outputs[i]) and outputs[i].requires_grad: |
| 143 | outputs_with_grad.append(outputs[i]) |
| 144 | args_with_grad.append(args[i]) |
| 145 | if len(outputs_with_grad) == 0: |
| 146 | raise RuntimeError("none of output has requires_grad=True," " this checkpoint() is not necessary") |
| 147 | torch.autograd.backward(outputs_with_grad, args_with_grad) |
no test coverage detected