Align phoneme recognition predictions to a known transcription.
align(
transcript: Iterable[SingleSegment],
model: torch.nn.Module,
align_model_metadata: dict,
audio: Union[str, np.ndarray, torch.Tensor],
device: str,
interpolate_method: str = "nearest",
return_char_alignments: bool = False,
print_progress: bool = False,
combined_progress: bool = False,
progress_callback: ProgressCallback = None,
)
| 115 | |
| 116 | |
| 117 | def align( |
| 118 | transcript: Iterable[SingleSegment], |
| 119 | model: torch.nn.Module, |
| 120 | align_model_metadata: dict, |
| 121 | audio: Union[str, np.ndarray, torch.Tensor], |
| 122 | device: str, |
| 123 | interpolate_method: str = "nearest", |
| 124 | return_char_alignments: bool = False, |
| 125 | print_progress: bool = False, |
| 126 | combined_progress: bool = False, |
| 127 | progress_callback: ProgressCallback = None, |
| 128 | ) -> AlignedTranscriptionResult: |
| 129 | """ |
| 130 | Align phoneme recognition predictions to known transcription. |
| 131 | """ |
| 132 | |
| 133 | if not torch.is_tensor(audio): |
| 134 | if isinstance(audio, str): |
| 135 | audio = load_audio(audio) |
| 136 | audio = torch.from_numpy(audio) |
| 137 | if len(audio.shape) == 1: |
| 138 | audio = audio.unsqueeze(0) |
| 139 | |
| 140 | MAX_DURATION = audio.shape[1] / SAMPLE_RATE |
| 141 | |
| 142 | model_dictionary = align_model_metadata["dictionary"] |
| 143 | model_lang = align_model_metadata["language"] |
| 144 | model_type = align_model_metadata["type"] |
| 145 | |
| 146 | # 1. Preprocess to keep only characters in dictionary |
| 147 | total_segments = len(transcript) |
| 148 | # Store temporary processing values |
| 149 | segment_data: dict[int, SegmentData] = {} |
| 150 | for sdx, segment in enumerate(transcript): |
| 151 | # strip spaces at beginning / end, but keep track of the amount. |
| 152 | if print_progress: |
| 153 | base_progress = ((sdx + 1) / total_segments) * 100 |
| 154 | percent_complete = (50 + base_progress / 2) if combined_progress else base_progress |
| 155 | print(f"Progress: {percent_complete:.2f}%...") |
| 156 | |
| 157 | num_leading = len(segment["text"]) - len(segment["text"].lstrip()) |
| 158 | num_trailing = len(segment["text"]) - len(segment["text"].rstrip()) |
| 159 | text = segment["text"] |
| 160 | |
| 161 | # split into words |
| 162 | if model_lang not in LANGUAGES_WITHOUT_SPACES: |
| 163 | per_word = text.split(" ") |
| 164 | else: |
| 165 | per_word = text |
| 166 | |
| 167 | clean_char, clean_cdx = [], [] |
| 168 | for cdx, char in enumerate(text): |
| 169 | char_ = char.lower() |
| 170 | # wav2vec2 models use "|" character to represent spaces |
| 171 | if model_lang not in LANGUAGES_WITHOUT_SPACES: |
| 172 | char_ = char_.replace(" ", "|") |
| 173 | |
| 174 | # ignore whitespace at beginning and end of transcript |
searching dependent graphs…