Forward step.
(data_iterator, model, args, timers)
| 273 | |
| 274 | |
def forward_step(data_iterator, model, args, timers):
    """Run one forward pass and return the mask-weighted mean loss.

    Args:
        data_iterator: Batch source; forwarded unchanged to ``get_batch``.
        model: Callable invoked as ``model(tokens, position_ids,
            attention_mask)``.
        args: Run configuration; only forwarded to ``get_batch`` here.
        timers: Callable mapping a timer name to an object exposing
            ``start()`` and ``stop()``.

    Returns:
        The sum of the per-element losses weighted by the loss mask,
        divided by the sum of the loss mask.
    """
    # Fetch the batch, timing only the data-loading portion.
    timers('batch generator').start()
    (input_tokens, target_labels, mask, attn_mask,
     pos_ids) = get_batch(data_iterator, args, timers)
    timers('batch generator').stop()

    # Model forward pass, then cross entropy computed on a contiguous
    # float copy of the model output.
    logits = model(input_tokens, pos_ids, attn_mask)
    per_token_loss = mpu.vocab_parallel_cross_entropy(
        logits.contiguous().float(), target_labels)

    # Flatten both tensors so the mask lines up element-for-element with
    # the losses, then take the mask-weighted average.
    # NOTE(review): if the mask sums to zero this division presumably
    # yields NaN -- confirm whether callers can produce such a batch.
    flat_mask = mask.view(-1)
    masked_total = (per_token_loss.view(-1) * flat_mask).sum()
    return masked_total / flat_mask.sum()
| 292 | |
| 293 | |
| 294 | def backward_step(optimizer, model, lm_loss, args, timers): |
no test coverage detected