MCP · Copy · Index your code
hub / github.com/zai-org/CodeGeeX / beam_search

Function beam_search

codegeex/megatron/code_generation_utils.py:661–723  ·  view source on GitHub ↗

Beam search. Note that this function does not support model parallelism.

(model, context_tokens, num_beams: int)

Source retrieved from the content-addressed store (hash-verified).

659
660
661def beam_search(model, context_tokens, num_beams: int):
662 """Beam search.
663
664 Note that this function does not support model parallel!
665 """
666 args = get_args()
667 tokenizer = get_tokenizer()
668
669 assert not isinstance(context_tokens[0], list), "batched beam search not supported"
670
671 initial_beam = Beam(context_tokens, 0.0)
672 context_len = len(context_tokens)
673 org_context_len = context_len
674 finished_beams = []
675
676 # first expansion
677 beams = expand_beams([initial_beam], num_beams, model)
678 context_len += 1
679
680 # print(f"initial beam: {initial_beam}")
681
682 while len(beams) > 0 and context_len < args.seq_length:
683 expanded_beams = expand_beams(beams, num_beams, model)
684 next_beams = []
685 for beam in expanded_beams:
686 if args.beam_warmup:
687 if len(beam.tokens) >= org_context_len + args.beam_warmup_length or beam.tokens[-1] == tokenizer.eod:
688 finished_beams.append(beam)
689 else:
690 next_beams.append(beam)
691 else:
692 if args.evaluation:
693 generated_code = tokenizer.detokenize(beam.tokens[org_context_len:])
694 if is_code_generation_finished(generated_code):
695 finished_beams.append(beam)
696 continue
697 if beam.tokens[-1] == tokenizer.eod:
698 finished_beams.append(beam)
699 else:
700 next_beams.append(beam)
701 # only keep top-k beams
702 next_beams.sort(key=lambda b: b.score, reverse=True)
703 beams = next_beams[:num_beams]
704 context_len += 1
705
706 if len(finished_beams) >= num_beams:
707 # first, only keep top-k beams
708 finished_beams.sort(key=lambda b: b.score, reverse=True)
709 finished_beams = finished_beams[:num_beams]
710 return finished_beams # return finished beams with highest scores
711 # stop if all currently expanding beams has a score lower than the minimal score of finished ones
712 min_score = min([b.score for b in finished_beams])
713 if min_score >= beams[0].score:
714 break
715 else:
716 print(f"we have got enough finished beams, but the minimal score is {min_score}")
717 print(f"and the maximum searching score is {beams[0].score}")
718

Callers 1

sample_sequence_batch (Function) · 0.85

Calls 6

get_args (Function) · 0.90
get_tokenizer (Function) · 0.90
Beam (Class) · 0.85
expand_beams (Function) · 0.85
detokenize (Method) · 0.45

Tested by

no test coverage detected