(ctx, run_function, activation_offload=False, *args)
| 41 | |
@staticmethod
def forward(ctx, run_function, activation_offload=False, *args):  # pylint: disable=W1113
    """Run ``run_function`` under ``torch.no_grad()`` and stash state on ``ctx``.

    Args:
        ctx: Autograd context object; every value the backward pass needs is
            attached to it here.
        run_function: Callable invoked as ``run_function(*args)`` (or with the
            device-copied arguments when offloading is enabled).
        activation_offload: When truthy, the arguments are passed through
            ``copy_to_device(args, ctx.device)`` before the call, and each
            tensor argument is kept as ``copy_to_device(arg, "cpu")`` on
            ``ctx.tensor_inputs``. Otherwise the tensor arguments are handed
            to ``ctx.save_for_backward``.
        *args: Positional arguments for ``run_function``; tensors and
            non-tensors may be mixed.

    Returns:
        Whatever ``run_function`` returns.
    """
    check_backward_validity(args)
    ctx.run_function = run_function
    ctx.activation_offload = activation_offload
    ctx.device = get_current_device()

    # Preserve RNG states. The call order is kept as-is: CPU state first, then
    # sync_states(), then the seed-state snapshot and the current mode.
    # NOTE(review): presumably backward restores these before re-running
    # run_function -- confirm against backward().
    ctx.fwd_cpu_rng_state = torch.get_rng_state()
    sync_states()
    ctx.fwd_seed_states = get_states(copy=True)
    ctx.fwd_current_mode = get_current_mode()

    # torch builds without `is_autocast_enabled` are treated as "autocast off".
    ctx.had_autocast_in_fwd = (
        torch.is_autocast_enabled() if hasattr(torch, "is_autocast_enabled") else False
    )

    call_args = copy_to_device(args, ctx.device) if activation_offload else args
    with torch.no_grad():
        outputs = run_function(*call_args)

    # Non-tensor arguments are stored directly in ctx.inputs; each tensor
    # leaves a None placeholder there (to be filled out during the backward)
    # and has its position recorded in ctx.tensor_indices.
    ctx.inputs = arg_slots = []
    ctx.tensor_indices = tensor_positions = []
    kept_tensors = []
    for position, value in enumerate(args):
        if not torch.is_tensor(value):
            arg_slots.append(value)
            continue
        kept_tensors.append(copy_to_device(value, "cpu") if activation_offload else value)
        tensor_positions.append(position)
        arg_slots.append(None)

    if activation_offload:
        ctx.tensor_inputs = kept_tensors
    else:
        ctx.save_for_backward(*kept_tensors)
    return outputs
| 88 | |
| 89 | @staticmethod |
| 90 | def backward(ctx, *args): |
nothing calls this directly
no test coverage detected