Beam search. Note that this function does not support model parallel!
(model, context_tokens, num_beams: int)
| 659 | |
| 660 | |
| 661 | def beam_search(model, context_tokens, num_beams: int): |
| 662 | """Beam search. |
| 663 | |
| 664 | Note that this function does not support model parallel! |
| 665 | """ |
| 666 | args = get_args() |
| 667 | tokenizer = get_tokenizer() |
| 668 | |
| 669 | assert not isinstance(context_tokens[0], list), "batched beam search not supported" |
| 670 | |
| 671 | initial_beam = Beam(context_tokens, 0.0) |
| 672 | context_len = len(context_tokens) |
| 673 | org_context_len = context_len |
| 674 | finished_beams = [] |
| 675 | |
| 676 | # first expansion |
| 677 | beams = expand_beams([initial_beam], num_beams, model) |
| 678 | context_len += 1 |
| 679 | |
| 680 | # print(f"initial beam: {initial_beam}") |
| 681 | |
| 682 | while len(beams) > 0 and context_len < args.seq_length: |
| 683 | expanded_beams = expand_beams(beams, num_beams, model) |
| 684 | next_beams = [] |
| 685 | for beam in expanded_beams: |
| 686 | if args.beam_warmup: |
| 687 | if len(beam.tokens) >= org_context_len + args.beam_warmup_length or beam.tokens[-1] == tokenizer.eod: |
| 688 | finished_beams.append(beam) |
| 689 | else: |
| 690 | next_beams.append(beam) |
| 691 | else: |
| 692 | if args.evaluation: |
| 693 | generated_code = tokenizer.detokenize(beam.tokens[org_context_len:]) |
| 694 | if is_code_generation_finished(generated_code): |
| 695 | finished_beams.append(beam) |
| 696 | continue |
| 697 | if beam.tokens[-1] == tokenizer.eod: |
| 698 | finished_beams.append(beam) |
| 699 | else: |
| 700 | next_beams.append(beam) |
| 701 | # only keep top-k beams |
| 702 | next_beams.sort(key=lambda b: b.score, reverse=True) |
| 703 | beams = next_beams[:num_beams] |
| 704 | context_len += 1 |
| 705 | |
| 706 | if len(finished_beams) >= num_beams: |
| 707 | # first, only keep top-k beams |
| 708 | finished_beams.sort(key=lambda b: b.score, reverse=True) |
| 709 | finished_beams = finished_beams[:num_beams] |
| 710 | return finished_beams # return finished beams with highest scores |
| 711 | # stop if all currently expanding beams have a score lower than the minimal score of finished ones |
| 712 | min_score = min([b.score for b in finished_beams]) |
| 713 | if min_score >= beams[0].score: |
| 714 | break |
| 715 | else: |
| 716 | print(f"we have got enough finished beams, but the minimal score is {min_score}") |
| 717 | print(f"and the maximum searching score is {beams[0].score}") |
| 718 |
no test coverage detected