Function align

whisperx/alignment.py:117–418  ·  github.com/m-bain/whisperX

Align phoneme recognition predictions to known transcription.

(
    transcript: Iterable[SingleSegment],
    model: torch.nn.Module,
    align_model_metadata: dict,
    audio: Union[str, np.ndarray, torch.Tensor],
    device: str,
    interpolate_method: str = "nearest",
    return_char_alignments: bool = False,
    print_progress: bool = False,
    combined_progress: bool = False,
    progress_callback: ProgressCallback = None,
)
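
A minimal usage sketch, assuming the `whisperx` package exposes `load_audio`, `load_align_model`, and `align` at the top level, as the project README shows. The wrapper name `align_segments` and its parameters are hypothetical and exist only for illustration; the positional arguments passed to `align` follow the signature above.

```python
def align_segments(audio_path: str, segments: list, language_code: str, device: str = "cpu") -> dict:
    """Hypothetical helper: align already-transcribed segments against their audio."""
    # Imported lazily so this sketch can be defined without whisperx installed.
    import whisperx

    # Decode the audio file into a waveform array.
    audio = whisperx.load_audio(audio_path)

    # Load the alignment model and its metadata for the given language.
    model_a, metadata = whisperx.load_align_model(language_code=language_code, device=device)

    # Positional order: transcript, model, align_model_metadata, audio, device.
    return whisperx.align(
        segments,
        model_a,
        metadata,
        audio,
        device,
        return_char_alignments=False,
    )
```

Calling the helper requires `whisperx` to be installed and may download an alignment model on first use. `segments` would typically be the `"segments"` list produced by an earlier transcription step.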

Source (excerpt: lines 117–174 of 117–418)

def align(
    transcript: Iterable[SingleSegment],
    model: torch.nn.Module,
    align_model_metadata: dict,
    audio: Union[str, np.ndarray, torch.Tensor],
    device: str,
    interpolate_method: str = "nearest",
    return_char_alignments: bool = False,
    print_progress: bool = False,
    combined_progress: bool = False,
    progress_callback: ProgressCallback = None,
) -> AlignedTranscriptionResult:
    """
    Align phoneme recognition predictions to known transcription.
    """

    if not torch.is_tensor(audio):
        if isinstance(audio, str):
            audio = load_audio(audio)
        audio = torch.from_numpy(audio)
    if len(audio.shape) == 1:
        audio = audio.unsqueeze(0)

    MAX_DURATION = audio.shape[1] / SAMPLE_RATE

    model_dictionary = align_model_metadata["dictionary"]
    model_lang = align_model_metadata["language"]
    model_type = align_model_metadata["type"]

    # 1. Preprocess to keep only characters in dictionary
    total_segments = len(transcript)
    # Store temporary processing values
    segment_data: dict[int, SegmentData] = {}
    for sdx, segment in enumerate(transcript):
        # strip spaces at beginning / end, but keep track of the amount.
        if print_progress:
            base_progress = ((sdx + 1) / total_segments) * 100
            percent_complete = (50 + base_progress / 2) if combined_progress else base_progress
            print(f"Progress: {percent_complete:.2f}%...")

        num_leading = len(segment["text"]) - len(segment["text"].lstrip())
        num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
        text = segment["text"]

        # split into words
        if model_lang not in LANGUAGES_WITHOUT_SPACES:
            per_word = text.split(" ")
        else:
            per_word = text

        clean_char, clean_cdx = [], []
        for cdx, char in enumerate(text):
            char_ = char.lower()
            # wav2vec2 models use "|" character to represent spaces
            if model_lang not in LANGUAGES_WITHOUT_SPACES:
                char_ = char_.replace(" ", "|")

            # ignore whitespace at beginning and end of transcript
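
The preprocessing arithmetic in the excerpt above can be tried in isolation. The sketch below re-expresses three of its steps as standalone functions: the progress percentage, the leading/trailing whitespace counts, and the per-character normalization. The function names and the `has_spaces` flag (standing in for `model_lang not in LANGUAGES_WITHOUT_SPACES`) are hypothetical and are not part of whisperX.

```python
from typing import Tuple


def progress_percent(sdx: int, total_segments: int, combined_progress: bool = False) -> float:
    """Percentage reported after segment index `sdx` (0-based) is reached."""
    base_progress = ((sdx + 1) / total_segments) * 100
    # With combined_progress, the value is mapped into the 50-100 range.
    return (50 + base_progress / 2) if combined_progress else base_progress


def whitespace_counts(text: str) -> Tuple[int, int]:
    """Number of whitespace characters at the start and at the end of `text`."""
    num_leading = len(text) - len(text.lstrip())
    num_trailing = len(text) - len(text.rstrip())
    return num_leading, num_trailing


def normalize_char(char: str, has_spaces: bool = True) -> str:
    """Lowercase a character; map a space to "|" for languages that use spaces."""
    char_ = char.lower()
    if has_spaces:
        char_ = char_.replace(" ", "|")
    return char_


print(progress_percent(0, 4))                          # 25.0
print(progress_percent(0, 4, combined_progress=True))  # 62.5
print(whitespace_counts("  hello world "))             # (2, 1)
print([normalize_char(c) for c in "A b"])              # ['a', '|', 'b']
```

With `combined_progress=True` the reported values cover only the upper half of the scale, presumably leaving the first half to an earlier stage of the pipeline.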

Callers (2)

transcribe_task · Function · 0.90
_run_align · Method · 0.90

Calls (5)

load_audio · Function · 0.90
interpolate_nans · Function · 0.90
get_trellis · Function · 0.85
backtrack · Function · 0.85
merge_repeats · Function · 0.85

Tested by (1)

_run_align · Method · 0.72
