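Add a new hypothesis to the list. Its score is `sum_logprobs` divided by the hypothesis length raised to `length_penalty`. The hypothesis is kept only if fewer than `num_beams` hypotheses are stored or its score beats the current worst score; when the list overflows, the lowest-scoring entry is evicted.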
Signature: `add(self, hyp: torch.LongTensor, sum_logprobs: float, mems=None)`
| 360 | return len(self.beams) |
| 361 | |
| 362 | def add(self, hyp: torch.LongTensor, sum_logprobs: float, mems=None): |
| 363 | """ |
| 364 | Add a new hypothesis to the list. |
| 365 | """ |
| 366 | score = sum_logprobs / (max(hyp.shape[-1], 1) ** self.length_penalty) |
| 367 | if len(self) < self.num_beams or score > self.worst_score: |
| 368 | self.beams.append((score, hyp, mems)) |
| 369 | if len(self) > self.num_beams: |
| 370 | sorted_next_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)]) |
| 371 | del self.beams[sorted_next_scores[0][1]] |
| 372 | self.worst_score = sorted_next_scores[1][0] |
| 373 | else: |
| 374 | self.worst_score = min(score, self.worst_score) |
| 375 | |
| 376 | def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool: |
| 377 | """ |
No test coverage detected.