Forward step.
(data_iterator, model, args, timers)
| 186 | |
| 187 | |
def forward_step(data_iterator, model, args, timers):
    """Run one forward pass and return the masked evaluation statistic.

    Args:
        data_iterator: Source of batches, passed through to ``get_batch``.
        model: Callable model. When ``args.eval_hf`` is set it is called as
            ``model(tokens)`` and must return a 2-tuple whose first element
            is the output; otherwise it is called as
            ``model(tokens, position_ids, attention_mask)``.
        args: Run configuration; ``eval_hf`` and ``cloze_eval`` are read here.
        timers: Callable returning a named timer with ``start()``/``stop()``.

    Returns:
        ``None`` when ``get_batch`` yields no batch. Otherwise a scalar
        tensor: the sum (not the mean) of the per-position losses weighted by
        ``loss_mask``, or, when ``args.cloze_eval`` is set, the
        ``loss_mask``-weighted count of positions whose argmax prediction
        equals the label.
    """
    # Get the batch. The timer is stopped in ``finally`` so that the early
    # return below (and any exception from ``get_batch``) cannot leave it
    # running, which would corrupt timings or break the next ``start()``.
    batch_timer = timers('batch generator')
    batch_timer.start()
    try:
        batch = get_batch(data_iterator, args, timers)
    finally:
        batch_timer.stop()
    if batch is None:
        return None
    tokens, lm_labels, attention_mask, position_ids, loss_mask = batch

    # Forward model.
    if args.eval_hf:
        output, _ = model(tokens)
    else:
        output = model(tokens, position_ids, attention_mask)

    # Both evaluation modes weight a flattened per-position quantity by the
    # flattened float mask.
    flat_mask = loss_mask.contiguous().view(-1).float()

    if not args.cloze_eval:
        # Per-position cross entropy computed in fp32, then masked and summed.
        losses = mpu.vocab_parallel_cross_entropy(
            output.contiguous().float(), lm_labels.contiguous())
        lm_loss = torch.sum(losses.view(-1) * flat_mask)
    else:
        # Cloze accuracy: 1.0 where the argmax over the last dimension
        # matches the label, masked and summed into a correct-count.
        predictions = torch.argmax(output, -1).contiguous().view(-1)
        correct = (predictions == lm_labels.contiguous().view(-1)).float()
        lm_loss = torch.sum(correct * flat_mask)

    return lm_loss
| 219 | |
| 220 | |
| 221 | def evaluate(data_loader, model, args, timers, |