| 151 | |
@staticmethod
def backward(ctx, *output_grads):
    """Recompute the checkpointed forward pass and backpropagate through it.

    The inputs saved on ``ctx`` are detached, ``ctx.run_function`` is re-run
    on them with autograd enabled, and the incoming ``output_grads`` are
    pulled back to ``ctx.input_tensors`` and ``ctx.input_params``.

    Only tensors that can actually carry a gradient are differentiated.
    These receive ``None`` instead of making ``th.autograd.grad`` raise:
    non-floating-point inputs (e.g. index tensors) and parameters with
    ``requires_grad=False`` (e.g. frozen layers). Recomputed outputs that
    are not part of the graph contribute nothing and are skipped.

    Args:
        ctx: autograd context holding ``run_function``, ``input_tensors``
            (a list) and ``input_params`` (a list). These are presumably
            populated by the matching ``forward``, which is not visible here.
        *output_grads: gradients w.r.t. each output of ``forward``.

    Returns:
        A tuple with one entry per ``forward`` argument. It starts with two
        leading ``None`` values for the two non-tensor arguments that precede
        the tensors (assumed to be ``run_function`` and the input-tensor
        count -- confirm against ``forward``). Then comes one gradient (or
        ``None``) per input tensor, then one per parameter, in that order.
    """
    # Fresh leaves to differentiate against; only floating/complex tensors
    # may require grad (requires_grad_(True) raises on integer tensors).
    ctx.input_tensors = [
        x.detach().requires_grad_(x.is_floating_point() or x.is_complex())
        for x in ctx.input_tensors
    ]
    with th.enable_grad():
        # Fixes a bug where the first op in run_function modifies the
        # Tensor storage in place, which is not allowed for detach()'d
        # Tensors.
        shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
        output_tensors = ctx.run_function(*shallow_copies)
    if isinstance(output_tensors, th.Tensor):
        output_tensors = (output_tensors,)

    # Outputs that do not require grad contribute nothing and would make
    # th.autograd.grad raise, so keep only (output, grad) pairs in the graph.
    grad_outputs, grad_seeds = [], []
    for out, seed in zip(output_tensors, output_grads):
        if isinstance(out, th.Tensor) and out.requires_grad:
            grad_outputs.append(out)
            grad_seeds.append(seed)

    candidates = ctx.input_tensors + ctx.input_params
    # th.autograd.grad also raises for inputs that do not require grad
    # (e.g. frozen parameters); request only the ones that do and leave
    # None in the other slots.
    wanted = [i for i, t in enumerate(candidates) if t.requires_grad]
    input_grads = [None] * len(candidates)
    if grad_outputs and wanted:
        computed = th.autograd.grad(
            grad_outputs,
            [candidates[i] for i in wanted],
            grad_seeds,
            allow_unused=True,
        )
        for i, g in zip(wanted, computed):
            input_grads[i] = g

    # Drop references to the recomputed graph and saved inputs promptly.
    del ctx.input_tensors
    del ctx.input_params
    del output_tensors, grad_outputs, candidates
    return (None, None) + tuple(input_grads)