Generate chunk. Args: speech: Speech audio tensor, shape (batch, time). speech_lengths: Length of each speech sample. key: Sample identifiers. tokenizer: Tokenizer instance for text encoding/decoding. frontend: Audio frontend for feature extraction. **kwargs: Additional keyword arguments.
(
self,
speech,
speech_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
)
| 597 | self.beam_search = beam_search |
| 598 | |
| 599 | def generate_chunk( |
| 600 | self, |
| 601 | speech, |
| 602 | speech_lengths=None, |
| 603 | key: list = None, |
| 604 | tokenizer=None, |
| 605 | frontend=None, |
| 606 | **kwargs, |
| 607 | ): |
| 608 | """Generate chunk. |
| 609 | |
| 610 | Args: |
| 611 | speech: Speech audio tensor, shape (batch, time). |
| 612 | speech_lengths: Length of each speech sample. |
| 613 | key: Sample identifiers. |
| 614 | tokenizer: Tokenizer instance for text encoding/decoding. |
| 615 | frontend: Audio frontend for feature extraction. |
| 616 | **kwargs: Additional keyword arguments. |
| 617 | """ |
| 618 | cache = kwargs.get("cache", {}) |
| 619 | speech = speech.to(device=kwargs["device"]) |
| 620 | speech_lengths = speech_lengths.to(device=kwargs["device"]) |
| 621 | |
| 622 | # Encoder |
| 623 | encoder_out, encoder_out_lens = self.encode_chunk( |
| 624 | speech, speech_lengths, cache=cache, is_final=kwargs.get("is_final", False) |
| 625 | ) |
| 626 | if isinstance(encoder_out, tuple): |
| 627 | encoder_out = encoder_out[0] |
| 628 | if "running_hyps" not in cache: |
| 629 | running_hyps = self.beam_search.init_hyp(encoder_out) |
| 630 | cache["running_hyps"] = running_hyps |
| 631 | |
| 632 | # predictor |
| 633 | predictor_outs = self.calc_predictor_chunk( |
| 634 | encoder_out, |
| 635 | encoder_out_lens, |
| 636 | cache=cache, |
| 637 | is_final=kwargs.get("is_final", False), |
| 638 | ) |
| 639 | pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = ( |
| 640 | predictor_outs[0], |
| 641 | predictor_outs[1], |
| 642 | predictor_outs[2], |
| 643 | predictor_outs[3], |
| 644 | ) |
| 645 | pre_token_length = pre_token_length.round().long() |
| 646 | |
| 647 | if torch.max(pre_token_length) < 1: |
| 648 | return [] |
| 649 | maxlen = minlen = pre_token_length |
| 650 | if kwargs.get("is_final", False): |
| 651 | maxlen += kwargs.get("token_num_relax", 5) |
| 652 | minlen = max(0, minlen - kwargs.get("token_num_relax", 5)) |
| 653 | # c. Passed the encoder result and the beam search |
| 654 | nbest_hyps = self.beam_search( |
| 655 | x=encoder_out[0], |
| 656 | scama_mask=None, |
no test coverage detected