hub / github.com/huggingface/parler-tts / apply_delay_pattern_mask

Function apply_delay_pattern_mask

parler_tts/modeling_parler_tts.py:205–211

Applies a delay pattern mask to the decoder input ids. The mask is first truncated to the current sequence length; positions where the mask is -1 keep the predicted token, and every other position is overwritten with the value stored in the mask.

(input_ids, decoder_pad_token_mask)

Source from the content-addressed store, hash-verified

def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask):
    """Apply a delay pattern mask to the decoder input ids, only preserving predictions where
    the mask is set to -1, and otherwise setting to the value detailed in the mask."""
    seq_len = input_ids.shape[-1]
    decoder_pad_token_mask = decoder_pad_token_mask[..., :seq_len]
    input_ids = torch.where(decoder_pad_token_mask == -1, input_ids, decoder_pad_token_mask)
    return input_ids
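
The behavior of the function in the listing above can be illustrated with a small usage sketch, assuming PyTorch is installed. The function body is copied from the listing; the token ids 1024 and 1025, the names BOS and PAD, and all tensor values are made up for illustration and are not taken from the Parler-TTS configuration.

```python
import torch


def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask):
    # Same logic as the listing: truncate the mask to the current length,
    # then keep input_ids only where the mask holds -1.
    seq_len = input_ids.shape[-1]
    decoder_pad_token_mask = decoder_pad_token_mask[..., :seq_len]
    return torch.where(decoder_pad_token_mask == -1, input_ids, decoder_pad_token_mask)


# Hypothetical special-token ids, chosen only for this example.
BOS, PAD = 1024, 1025

# Hypothetical mask for 2 codebooks over 5 steps; -1 marks "keep the prediction".
mask = torch.tensor([
    [BOS, -1, -1, -1, PAD],
    [BOS, BOS, -1, -1, -1],
])

# Only 3 steps have been generated so far, so the mask is cut to 3 columns.
generated = torch.tensor([
    [7, 11, 12],
    [8, 9, 21],
])

out = apply_delay_pattern_mask(generated, mask)
print(out.tolist())  # [[1024, 11, 12], [1024, 1024, 21]]
```

Because the mask is sliced with `[..., :seq_len]`, a mask built once for the full generation length can be reapplied at every decoding step without resizing it.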

Callers 1

Calls

no outgoing calls

Tested by

no test coverage detected