MCPcopy
hub / github.com/deepspeedai/DeepSpeedExamples / forward_step

Function forward_step

Megatron-LM/evaluate_gpt2.py:188–218  ·  view source on GitHub ↗

Forward step.

(data_iterator, model, args, timers)

Source from the content-addressed store, hash-verified

186
187
def forward_step(data_iterator, model, args, timers):
    """Forward step.

    Fetch one evaluation batch and run the model on it. Return either the
    summed masked LM loss or the count of correct masked predictions.

    Args:
        data_iterator: Passed straight through to ``get_batch``.
        model: Called as ``model(tokens)`` when ``args.eval_hf`` is set.
            In that case it must return a pair, and the first element is
            used as the output. Otherwise it is called as
            ``model(tokens, position_ids, attention_mask)``.
        args: Run configuration. This function reads ``args.eval_hf`` and
            ``args.cloze_eval``, and also forwards ``args`` to ``get_batch``.
        timers: Callable mapping a name to a timer object that exposes
            ``start()`` and ``stop()``.

    Returns:
        ``None`` if ``get_batch`` returns ``None``. Otherwise a scalar tensor:

        - When ``args.cloze_eval`` is false: the sum of per-token
          ``mpu.vocab_parallel_cross_entropy`` losses, weighted by
          ``loss_mask``.
        - When it is true: the number of mask-weighted positions whose
          argmax prediction equals the label.

        Both values are sums, not means. Presumably the caller normalizes;
        verify against ``evaluate``.
    """
    # Get the batch. The timer is stopped before the early return so it is
    # never left running when get_batch returns None.
    timers('batch generator').start()
    batch = get_batch(data_iterator, args, timers)
    timers('batch generator').stop()
    if batch is None:
        return None
    tokens, lm_labels, attention_mask, position_ids, loss_mask = batch

    # Forward model.
    if args.eval_hf:
        # This model variant returns a pair; only the first element is used.
        output, _ = model(tokens)
    else:
        output = model(tokens, position_ids, attention_mask)

    # Flattened float mask, shared by both metrics below.
    loss_mask = loss_mask.contiguous().view(-1).float()

    if not args.cloze_eval:
        losses = mpu.vocab_parallel_cross_entropy(
            output.contiguous().float(), lm_labels.contiguous())
        lm_loss = torch.sum(losses.view(-1) * loss_mask)
    else:
        # Count positions where the top-scoring entry along the last
        # dimension matches the label, restricted by the mask.
        predictions = torch.argmax(output, -1).contiguous().view(-1)
        acc = (predictions == lm_labels.contiguous().view(-1)).float()
        lm_loss = torch.sum(acc * loss_mask)

    return lm_loss
219
220
221def evaluate(data_loader, model, args, timers,

Callers 1

evaluate · Function · 0.70

Calls 3

get_batch · Function · 0.70
start · Method · 0.45
stop · Method · 0.45

Tested by

no test coverage detected