MCPcopy
hub / github.com/EleutherAI/gpt-neox / forward

Method forward

megatron/model/utils.py:124–189  ·  view source on GitHub ↗
(self, forward_input, curriculum_seqlen=None, labels=None, neox_args=None)

Source from the content-addressed store, hash-verified

122 _set_use_cache(self.sequential, False)
123
124 def forward(
125 self, forward_input, curriculum_seqlen=None, labels=None, neox_args=None
126 ):
127
128 if (
129 curriculum_seqlen is not None
130 and isinstance(forward_input, tuple)
131 and len(forward_input) == 3
132 ):
133 neox_args.update_value("curriculum_seqlen", curriculum_seqlen)
134 tokens = forward_input[0]
135 input_ids = forward_input[1]
136 attention_mask = forward_input[2]
137 if curriculum_seqlen < input_ids.size()[1]:
138 # seqlen-based curriculum learning
139 # input_ids, position_ids, labels have size [batch size, seqlen]
140 input_ids = input_ids[:, :curriculum_seqlen].contiguous()
141 tokens = tokens[:, :curriculum_seqlen].contiguous()
142 # position_ids = position_ids[:, :curriculum_seqlen].contiguous()
143 if labels is not None:
144 labels = labels[:, :curriculum_seqlen].contiguous()
145 # attention_mask has size [1, 1, seqlen, seqlen]
146 attention_mask = attention_mask[
147 :, :, :curriculum_seqlen, :curriculum_seqlen
148 ].contiguous()
149 forward_input = (tokens, input_ids, attention_mask)
150
def exec_range_func(start, end):
    """Build a callable that runs ``self.sequential[start:end]`` in order.

    Helper function to be used with checkpoint().
    Adapted from torch.utils.checkpoint:checkpoint_sequential().

    ``self`` is captured from the enclosing ``forward`` call.

    Args:
        start: Index of the first layer to run (inclusive).
        end: Index one past the last layer to run (exclusive).

    Returns:
        A function accepting ``*inputs`` that passes them through each
        layer of the slice, feeding every layer's output into the next,
        and returns the output of the last layer (or the inputs unchanged
        when the slice is empty).
    """

    def exec_func(*inputs):
        # ``*inputs`` always arrives as a tuple; a lone argument is
        # unwrapped so the first layer receives it as-is, not as a 1-tuple.
        if len(inputs) == 1:
            inputs = inputs[0]
        # The layer index is not needed, so iterate the slice directly.
        for layer in self.sequential[start:end]:
            inputs = layer(inputs)
        return inputs

    return exec_func
165
166 if self.activation_checkpoint_interval == 0:
167 func = exec_range_func(0, len(self.sequential))
168 x = func(forward_input)
169 else:
170 num_layers = len(self.sequential)
171 x = forward_input
172 for start_idx in range(0, num_layers, self.activation_checkpoint_interval):
173 end_idx = min(
174 start_idx + self.activation_checkpoint_interval, num_layers
175 )
176
177 funcs = self.sequential[start_idx:end_idx]
178 # Since we either pass tensors or tuples of tensors without unpacking, we
179 # need to be careful not to double-wrap tensors with tuple.
180 if not isinstance(x, tuple):
181 x = (x,)

Callers

nothing calls this directly

Calls 3

_is_checkpointable (Method) · 0.95
update_value (Method) · 0.80
size (Method) · 0.80

Tested by

no test coverage detected