Run inference on input data. Args: data_in: Input data (audio samples, file paths, or text). data_lengths: Lengths of each input sample in the batch. key: Sample identifiers. tokenizer: Tokenizer instance for text encoding/decoding. frontend: Audio frontend for feature extraction. **kwargs: Additional keyword arguments.
(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
)
| 448 | self.beam_search = beam_search |
| 449 | |
| 450 | def inference( |
| 451 | self, |
| 452 | data_in, |
| 453 | data_lengths=None, |
| 454 | key: list = None, |
| 455 | tokenizer=None, |
| 456 | frontend=None, |
| 457 | **kwargs, |
| 458 | ): |
| 459 | |
| 460 | """Run inference on input data. |
| 461 | |
| 462 | Args: |
| 463 | data_in: Input data (audio samples, file paths, or text). |
| 464 | data_lengths: Lengths of each input sample in the batch. |
| 465 | key: Sample identifiers. |
| 466 | tokenizer: Tokenizer instance for text encoding/decoding. |
| 467 | frontend: Audio frontend for feature extraction. |
| 468 | **kwargs: Additional keyword arguments. |
| 469 | """ |
| 470 | if kwargs.get("batch_size", 1) > 1: |
| 471 | raise NotImplementedError("batch decoding is not implemented") |
| 472 | |
| 473 | # init beamsearch |
| 474 | if self.beam_search is None: |
| 475 | logging.info("enable beam_search") |
| 476 | self.init_beam_search(**kwargs) |
| 477 | self.nbest = kwargs.get("nbest", 1) |
| 478 | |
| 479 | meta_data = {} |
| 480 | if ( |
| 481 | isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank" |
| 482 | ): # fbank |
| 483 | speech, speech_lengths = data_in, data_lengths |
| 484 | if len(speech.shape) < 3: |
| 485 | speech = speech[None, :, :] |
| 486 | if speech_lengths is None: |
| 487 | speech_lengths = speech.shape[1] |
| 488 | else: |
| 489 | # extract fbank feats |
| 490 | time1 = time.perf_counter() |
| 491 | sample_list = load_audio_text_image_video( |
| 492 | data_in, |
| 493 | fs=frontend.fs, |
| 494 | audio_fs=kwargs.get("fs", 16000), |
| 495 | data_type=kwargs.get("data_type", "sound"), |
| 496 | tokenizer=tokenizer, |
| 497 | ) |
| 498 | time2 = time.perf_counter() |
| 499 | meta_data["load_data"] = f"{time2 - time1:0.3f}" |
| 500 | audio_sample_list = sample_list[0] |
| 501 | if len(sample_list) > 1: |
| 502 | ocr_sample_list = sample_list[1] |
| 503 | else: |
| 504 | ocr_sample_list = [[294, 0]] |
| 505 | speech, speech_lengths = extract_fbank( |
| 506 | audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend |
| 507 | ) |
No direct callers of this method were detected.
No test coverage was detected for this method.