| 385 | |
@staticmethod
def backward(ctx, *output_grads):
    """Recompute the forward pass and return gradients for the saved inputs.

    The tensors stored on ``ctx.input_tensors`` are replaced by detached
    copies that require grad, ``ctx.run_function`` is run again on them with
    gradient tracking enabled, and the incoming ``output_grads`` are
    back-propagated through that fresh graph.

    Args:
        ctx: Context object providing ``input_tensors``, ``input_params``,
            ``run_function`` and ``gpu_autocast_kwargs``.
        *output_grads: Gradients with respect to the outputs of
            ``ctx.run_function``.

    Returns:
        A tuple starting with two ``None`` entries, followed by one gradient
        per element of ``ctx.input_tensors`` and then ``ctx.input_params``.
        An entry is ``None`` when that input did not contribute to the
        output (``allow_unused=True``).
        NOTE(review): the two leading ``None`` values presumably line up with
        two non-differentiable arguments of the matching ``forward`` --
        confirm against it.
    """
    # Swap the saved inputs for detached leaves that track gradients.
    tracked_inputs = []
    for saved in ctx.input_tensors:
        tracked_inputs.append(saved.detach().requires_grad_(True))
    ctx.input_tensors = tracked_inputs

    with torch.enable_grad():
        with torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
            # Hand run_function views rather than the detached tensors
            # themselves: if its first op writes to its input in place, that
            # is not allowed on a detach()'d tensor.
            input_views = [leaf.view_as(leaf) for leaf in ctx.input_tensors]
            recomputed = ctx.run_function(*input_views)

    grads = torch.autograd.grad(
        outputs=recomputed,
        inputs=ctx.input_tensors + ctx.input_params,
        grad_outputs=output_grads,
        allow_unused=True,
    )

    # Drop the references held by ctx and by this frame.
    del ctx.input_tensors, ctx.input_params
    del recomputed
    return (None, None, *grads)