MCPcopy
hub / github.com/InternLM/InternLM / forward

Method forward

internlm/utils/checkpoint.py:43–87  ·  view source on GitHub ↗
(ctx, run_function, activation_offload=False, *args)

Source from the content-addressed store, hash-verified

41
@staticmethod
def forward(ctx, run_function, activation_offload=False, *args):  # pylint: disable=W1113
    """Execute ``run_function`` without recording a graph and stash what backward needs.

    Args:
        ctx: Autograd context object; state for the backward pass is stored on it.
        run_function: Callable invoked as ``run_function(*args)`` under ``torch.no_grad()``.
        activation_offload: When True, ``args`` are copied to the current device for the
            call, and tensor inputs are kept as CPU copies in ``ctx.tensor_inputs``
            instead of being passed to ``ctx.save_for_backward``.
        *args: Positional arguments forwarded to ``run_function``.

    Returns:
        Whatever ``run_function`` returns.
    """
    check_backward_validity(args)
    ctx.run_function = run_function
    ctx.activation_offload = activation_offload
    ctx.device = get_current_device()

    # Preserve RNG state (CPU generator plus the synced seed-manager states and
    # current mode) before running the function; stored on ctx for backward.
    ctx.fwd_cpu_rng_state = torch.get_rng_state()
    sync_states()
    ctx.fwd_seed_states = get_states(copy=True)
    ctx.fwd_current_mode = get_current_mode()

    # torch builds without is_autocast_enabled are treated as autocast-off.
    ctx.had_autocast_in_fwd = (
        torch.is_autocast_enabled() if hasattr(torch, "is_autocast_enabled") else False
    )

    call_args = copy_to_device(args, ctx.device) if activation_offload else args
    with torch.no_grad():
        outputs = run_function(*call_args)

    # Non-tensor args are stored verbatim; tensor positions get a None
    # placeholder in ctx.inputs, to be filled out during the backward.
    ctx.inputs = []
    ctx.tensor_indices = []
    stashed = []
    for position, value in enumerate(args):
        if not torch.is_tensor(value):
            ctx.inputs.append(value)
            continue
        stashed.append(copy_to_device(value, "cpu") if activation_offload else value)
        ctx.tensor_indices.append(position)
        ctx.inputs.append(None)

    if activation_offload:
        # Offloaded CPU copies are held directly on ctx rather than saved via autograd.
        ctx.tensor_inputs = stashed
    else:
        ctx.save_for_backward(*stashed)
    return outputs
88
89 @staticmethod
90 def backward(ctx, *args):

Callers

nothing calls this directly

Calls 5

sync_states — Function · 0.90
get_states — Function · 0.90
get_current_mode — Function · 0.90
get_current_device — Function · 0.85
copy_to_device — Function · 0.85

Tested by

no test coverage detected