(
self,
input_ids: torch.LongTensor,
final_beam_scores: torch.FloatTensor,
final_beam_tokens: torch.LongTensor,
final_beam_indices: torch.LongTensor,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
mems=None
)
| 285 | ) |
| 286 | |
def finalize(
    self,
    input_ids: torch.LongTensor,
    final_beam_scores: torch.FloatTensor,
    final_beam_tokens: torch.LongTensor,
    final_beam_indices: torch.LongTensor,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    mems=None
) -> Tuple[torch.LongTensor, Optional[List[torch.Tensor]], torch.FloatTensor]:
    """Close all open beams and assemble the best hypotheses per batch entry.

    For every batch entry not yet flagged in ``self._done``, the
    ``self.num_beams`` rows of ``input_ids`` belonging to it are added to its
    hypothesis container together with their final score (and their slice of
    ``mems`` when given). Then the ``self.num_beam_hyps_to_keep`` highest
    scoring hypotheses of each entry are copied into one output tensor.

    Args:
        input_ids: Token rows, indexed as ``batch_idx * num_beams + beam_id``.
        final_beam_scores: One score per row of ``input_ids``.
        final_beam_tokens: Not used by this implementation.
        final_beam_indices: Not used by this implementation.
        pad_token_id: Fill value for rows shorter than the longest selected
            hypothesis. Required only when the selected lengths differ.
        eos_token_id: Written right after a hypothesis that is shorter than
            the longest one. When ``None`` that slot keeps the pad value.
        mems: Optional sequence of tensors indexed by row along dim 0
            (presumably per-layer cached states -- confirm with the caller).

    Returns:
        ``(decoded, mems, scores)``: the selected token rows, the per-layer
        concatenation (dim 0) of the selected hypotheses' mems or ``None``
        when the first selected hypothesis carries none, and the selected
        hypotheses' scores.

    Raises:
        ValueError: If padding is needed but ``pad_token_id`` is ``None``.
    """
    batch_size = len(self._beam_hyps)
    num_keep = self.num_beam_hyps_to_keep

    # finalize all open beam hypotheses and add to generated hypotheses
    for batch_idx, beam_hyp in enumerate(self._beam_hyps):
        if self._done[batch_idx]:
            continue

        # need to add best num_beams hypotheses to generated hyps
        for beam_id in range(self.num_beams):
            batch_beam_idx = batch_idx * self.num_beams + beam_id
            final_score = final_beam_scores[batch_beam_idx].item()
            final_tokens = input_ids[batch_beam_idx]
            # index with a list so every mem slice keeps its leading dim
            beam_mems = [mem[[batch_beam_idx]] for mem in mems] if mems else None
            beam_hyp.add(final_tokens, final_score, mems=beam_mems)

    # select the best hypotheses
    sent_lengths = input_ids.new(batch_size * num_keep)
    best = []

    # retrieve best hypotheses
    for i, beam_hyp in enumerate(self._beam_hyps):
        # ascending sort + pop() yields the highest score first
        sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
        for j in range(num_keep):
            # distinct local name: do not rebind the `mems` parameter
            score, best_hyp, hyp_mems = sorted_hyps.pop()
            sent_lengths[num_keep * i + j] = len(best_hyp)
            best.append((best_hyp, hyp_mems, score))

    # prepare for adding eos
    sent_max_len = sent_lengths.max().item()
    decoded: torch.LongTensor = input_ids.new(batch_size * num_keep, sent_max_len)
    scores = final_beam_scores.new(batch_size * num_keep)
    # shorter batches are padded if needed
    if sent_lengths.min().item() != sent_max_len:
        # explicit raise instead of assert: asserts are stripped under -O
        if pad_token_id is None:
            raise ValueError("`pad_token_id` has to be defined")
        decoded.fill_(pad_token_id)

    # fill with hypotheses and eos_token_id if the latter fits in
    selected_mems = []
    for i, (hypo, mem, score) in enumerate(best):
        scores[i] = score
        decoded[i, : sent_lengths[i]] = hypo
        # eos_token_id is optional; without it the slot keeps the pad value
        if sent_lengths[i] < sent_max_len and eos_token_id is not None:
            decoded[i, sent_lengths[i]] = eos_token_id
        selected_mems.append(mem)

    # merge mems layer by layer across the selected hypotheses
    if selected_mems and selected_mems[0]:
        merged_mems = [
            torch.cat([mem[layer] for mem in selected_mems], dim=0)
            for layer in range(len(selected_mems[0]))
        ]
    else:
        merged_mems = None
    return decoded, merged_mems, scores
| 342 | |
| 343 | |
| 344 | class BeamHypotheses: |
no test coverage detected