MCPcopy Index your code
hub / github.com/THUDM/GLM / finalize

Method finalize

generation_utils.py:287–341  ·  view source on GitHub ↗
(
            self,
            input_ids: torch.LongTensor,
            final_beam_scores: torch.FloatTensor,
            final_beam_tokens: torch.LongTensor,
            final_beam_indices: torch.LongTensor,
            pad_token_id: Optional[int] = None,
            eos_token_id: Optional[int] = None,
            mems=None
    )

Source from the content-addressed store, hash-verified

285 )
286
def finalize(
        self,
        input_ids: torch.LongTensor,
        final_beam_scores: torch.FloatTensor,
        final_beam_tokens: torch.LongTensor,
        final_beam_indices: torch.LongTensor,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        mems=None
) -> Tuple[torch.LongTensor, Optional[List[torch.Tensor]], torch.FloatTensor]:
    """Close out beam search and pack the best hypotheses of every batch entry.

    First, the ``self.num_beams`` live beams of each batch entry not yet
    marked done in ``self._done`` are added to that entry's hypothesis
    container. Then the ``self.num_beam_hyps_to_keep`` highest-scoring
    hypotheses of every entry are written, best first, into one token
    matrix. Rows shorter than the longest kept hypothesis are padded.

    Args:
        input_ids: Token ids of the live beams. Row
            ``batch_idx * num_beams + beam_id`` belongs to that beam.
        final_beam_scores: Score of each live beam, indexed like ``input_ids``.
        final_beam_tokens: Not read by this method; kept for interface
            compatibility.
        final_beam_indices: Not read by this method; kept for interface
            compatibility.
        pad_token_id: Filler for rows shorter than the longest kept
            hypothesis. Required whenever the kept lengths differ.
        eos_token_id: If given, written right after each shorter hypothesis.
            If ``None``, no EOS is written and the shorter row is only padded.
        mems: Optional list of tensors whose first dimension is indexed like
            ``input_ids``. Each live beam's slice is stored with its hypothesis.

    Returns:
        A tuple ``(decoded, mems, scores)``:

        - ``decoded``: ``(batch_size * num_beam_hyps_to_keep, max_len)`` token ids.
        - ``mems``: the stored memories of the selected hypotheses,
          concatenated along dim 0 in the row order of ``decoded``, or
          ``None`` when no memories were stored.
        - ``scores``: the score of each row of ``decoded``.

    Raises:
        AssertionError: if kept hypotheses differ in length and
            ``pad_token_id`` is ``None``.
        IndexError: if a hypothesis container holds fewer than
            ``num_beam_hyps_to_keep`` hypotheses.
    """
    batch_size = len(self._beam_hyps)
    keep = self.num_beam_hyps_to_keep

    # Flush the still-open beams of unfinished batch entries into their
    # hypothesis containers so they compete with already-finished ones.
    for batch_idx, beam_hyp in enumerate(self._beam_hyps):
        if self._done[batch_idx]:
            continue
        for beam_id in range(self.num_beams):
            batch_beam_idx = batch_idx * self.num_beams + beam_id
            final_score = final_beam_scores[batch_beam_idx].item()
            final_tokens = input_ids[batch_beam_idx]
            # Index with a list so dim 0 survives (size 1); the slices are
            # concatenated back together with torch.cat below.
            beam_mems = [mem[[batch_beam_idx]] for mem in mems] if mems else None
            beam_hyp.add(final_tokens, final_score, mems=beam_mems)

    # Select the `keep` best hypotheses per batch entry, best first.
    # Container entries are (score, token_ids, mems) tuples.
    sent_lengths = input_ids.new(batch_size * keep)
    best = []
    for i, beam_hyp in enumerate(self._beam_hyps):
        sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
        for j in range(keep):
            score, best_hyp, hyp_mems = sorted_hyps.pop()
            sent_lengths[keep * i + j] = len(best_hyp)
            best.append((best_hyp, hyp_mems, score))

    sent_max_len = sent_lengths.max().item()
    decoded: torch.LongTensor = input_ids.new(batch_size * keep, sent_max_len)
    scores = final_beam_scores.new(batch_size * keep)
    # `new` allocates without initialising. When every row has the same length,
    # each row is fully overwritten below; otherwise pre-fill with padding.
    if sent_lengths.min().item() != sent_max_len:
        assert pad_token_id is not None, "`pad_token_id` has to be defined"
        decoded.fill_(pad_token_id)

    selected_mems = []
    for i, (hypo, hyp_mems, score) in enumerate(best):
        scores[i] = score
        decoded[i, : sent_lengths[i]] = hypo
        # Writing None into a LongTensor would raise, so an absent eos id
        # leaves the shorter row padded only.
        if eos_token_id is not None and sent_lengths[i] < sent_max_len:
            decoded[i, sent_lengths[i]] = eos_token_id
        selected_mems.append(hyp_mems)

    if selected_mems and selected_mems[0]:
        out_mems = [
            torch.cat([hyp_mems[k] for hyp_mems in selected_mems], dim=0)
            for k in range(len(selected_mems[0]))
        ]
    else:
        out_mems = None
    return decoded, out_mems, scores
342
343
344class BeamHypotheses:

Callers 2

sample_sequence · Function · 0.95
evaluate · Method · 0.95

Calls 2

append · Method · 0.80
add · Method · 0.45

Tested by

no test coverage detected