Calculate the number of correct answers over the total number of answers, and return the predictions if `output_predictions` is true.
(self, model, dataloader, example_dict, args)
| 265 | self.processors.append(processor) |
| 266 | |
| 267 | def evaluate(self, model, dataloader, example_dict, args): |
| 268 | """Calculate correct over total answers and return prediction if the |
| 269 | `output_predictions` is true.""" |
| 270 | model.eval() |
| 271 | local_predictions = {} |
| 272 | print_rank_0("Distributed store created") |
| 273 | with torch.no_grad(): |
| 274 | # For all the batches in the dataset. |
| 275 | for idx, data in enumerate(dataloader): |
| 276 | tokens, attention_mask, position_ids = process_batch(data, args) |
| 277 | batch_size = tokens.size(0) |
| 278 | beam_scorer = BeamSearchScorer( |
| 279 | batch_size=batch_size, |
| 280 | max_length=args.out_seq_length, |
| 281 | num_beams=args.num_beams, |
| 282 | device=tokens.device, |
| 283 | length_penalty=args.length_penalty, |
| 284 | do_early_stopping=False, |
| 285 | ) |
| 286 | beam_scores = torch.zeros((batch_size, args.num_beams), dtype=torch.float, device=tokens.device) |
| 287 | beam_scores[:, 1:] = -1e9 |
| 288 | beam_scores = beam_scores.view((batch_size * args.num_beams,)) |
| 289 | # Run the model forward. |
| 290 | counter = 0 |
| 291 | context_length = tokens.size(1) |
| 292 | while counter < args.tgt_seq_length: |
| 293 | if counter == 0: |
| 294 | next_token_logits, *mems = model(tokens, position_ids, attention_mask, return_memory=True) |
| 295 | seq_length = next_token_logits.size(1) |
| 296 | next_token_logits = next_token_logits[:, -1] |
| 297 | next_token_logits = next_token_logits.unsqueeze(1).repeat(1, args.num_beams, 1).view( |
| 298 | batch_size * args.num_beams, -1) |
| 299 | mems = [mem.unsqueeze(1).repeat(1, args.num_beams, 1, 1).view(batch_size * args.num_beams, |
| 300 | seq_length, -1) for mem in mems] |
| 301 | position_ids = tokens.new_ones(batch_size, args.num_beams, 2, 1) |
| 302 | for i, text in enumerate(tokens.tolist()): |
| 303 | mask_pos = text.index(self.mask_token) |
| 304 | position_ids[i, :, 0] = mask_pos |
| 305 | position_ids = position_ids.reshape(batch_size * args.num_beams, 2, 1) |
| 306 | tokens = tokens.new_zeros(batch_size * args.num_beams, 0) |
| 307 | else: |
| 308 | if not args.no_block_position: |
| 309 | position_ids[:, 1] = counter + 1 |
| 310 | last_token = tokens[:, -1:] |
| 311 | if self.mask_pad_token: |
| 312 | cur_attention_mask = attention_mask[:, :, -1:, :].unsqueeze(1).expand(-1, args.num_beams, -1, |
| 313 | -1, -1).reshape( |
| 314 | batch_size * args.num_beams, 1, 1, context_length) |
| 315 | cur_attention_mask = torch.cat( |
| 316 | (cur_attention_mask, attention_mask.new_ones((batch_size * args.num_beams, 1, 1, counter))), |
| 317 | dim=-1) |
| 318 | else: |
| 319 | cur_attention_mask = tokens.new_zeros([batch_size * args.num_beams]) |
| 320 | next_token_logits, *mems = model(last_token, position_ids, cur_attention_mask, *mems, |
| 321 | return_memory=True) |
| 322 | next_token_logits = next_token_logits[:, -1] |
| 323 | next_token_logits = top_k_logits(next_token_logits, top_k=args.top_k, top_p=args.top_p) |
| 324 | next_token_scores = F.log_softmax(next_token_logits, dim=-1) |
Nothing calls `evaluate` directly.
No test coverage was detected for it.