Forward step.
(data_iterator, model, args, timers)
| 203 | |
| 204 | |
def forward_step(data_iterator, model, args, timers):
    """Run a single forward pass and compute both training losses.

    Fetches one batch through ``get_batch`` (timed under the
    'batch generator' timer), feeds it to ``model``, and derives:

    * ``nsp_loss`` -- ``F.cross_entropy`` over the model's second output
      reshaped to two columns, with target value -1 ignored.
    * ``lm_loss`` -- the ``mpu.vocab_parallel_cross_entropy`` losses
      weighted by ``loss_mask`` and divided by ``loss_mask.sum()``.

    Args:
        data_iterator: Passed straight through to ``get_batch``.
        model: Callable invoked as
            ``model(tokens, types, mask, checkpoint_activations=...)``
            and returning a pair of outputs.
        args: Only ``args.checkpoint_activations`` is read here.
        timers: Callable mapping a timer name to an object exposing
            ``start()`` / ``stop()``; also handed to ``get_batch``.

    Returns:
        The tuple ``(lm_loss, nsp_loss)``.
    """
    # Pull the next batch, timing only the fetch itself.
    timers('batch generator').start()
    (tokens, types, next_sentence,
     loss_mask, lm_labels, padding_mask) = get_batch(data_iterator, timers)
    timers('batch generator').stop()

    # The model is given the complement of the padding mask.
    inverted_padding_mask = 1 - padding_mask
    lm_output, nsp_output = model(
        tokens, types, inverted_padding_mask,
        checkpoint_activations=args.checkpoint_activations)

    # Next-sentence loss: two-way cross entropy computed in float32;
    # targets equal to -1 do not contribute.
    nsp_inputs = nsp_output.view(-1, 2).contiguous().float()
    nsp_targets = next_sentence.view(-1).contiguous()
    nsp_loss = F.cross_entropy(nsp_inputs, nsp_targets, ignore_index=-1)

    # Language-model loss: mask-weighted sum of the losses, normalised
    # by the mask total. NOTE(review): there is no guard for a mask that
    # sums to zero.
    lm_losses = mpu.vocab_parallel_cross_entropy(
        lm_output.contiguous().float(), lm_labels.contiguous())
    loss_mask = loss_mask.contiguous()
    masked_total = torch.sum(
        lm_losses.view(-1) * loss_mask.view(-1).float())
    lm_loss = masked_total / loss_mask.sum()

    return lm_loss, nsp_loss
| 228 | |
| 229 | |
| 230 | def backward_step(optimizer, model, lm_loss, nsp_loss, args): |
no test coverage detected