(self, input_ids)
| 64 | self.timeout = timeout |
| 65 | |
def apply_delay_pattern_mask(self, input_ids):
    """Revert the codebook delay pattern on ``input_ids`` and decode them to audio.

    The delay pattern mask is built and applied through ``self.decoder``, the
    positions the mask marks as BOS/pad are filtered out, and the remaining
    codes are decoded with ``self.audio_encoder``.

    Args:
        input_ids: token ids to decode. The filtered codes are reshaped to
            ``(1, num_codebooks, -1)``, so a single sequence is assumed
            (NOTE(review): confirm callers never pass batch size > 1).

    Returns:
        ``output[0, 0]`` of the decoded audio values as a float NumPy array
        on the CPU.
    """
    gen_cfg = self.generation_config

    # Each codebook prediction is offset by 1 (this behaviour is specific to
    # Parler), so build the matching delay pattern mask and apply it.
    _, pattern = self.decoder.build_delay_pattern_mask(
        input_ids[:, :1],
        bos_token_id=gen_cfg.bos_token_id,
        pad_token_id=gen_cfg.decoder_start_token_id,
        max_length=input_ids.shape[-1],
    )
    delayed_ids = self.decoder.apply_delay_pattern_mask(input_ids, pattern)

    # Undo the delay: keep only positions the pattern marks as neither BOS nor pad.
    is_special = (pattern == gen_cfg.bos_token_id) | (pattern == gen_cfg.pad_token_id)
    codes = delayed_ids[~is_special].reshape(1, self.decoder.num_codebooks, -1)

    if self.use_4dim_audio_codes:
        # Put the frame dimension back in front of the audio codes.
        codes = codes.unsqueeze(0)

    # The codes must live on the same device as the audio encoder.
    codes = codes.to(self.audio_encoder.device)

    special_ids = (gen_cfg.bos_token_id, gen_cfg.pad_token_id, gen_cfg.eos_token_id)
    if any(token_id in codes for token_id in special_ids):
        # Special tokens are still present: drop every time step in which any
        # codebook holds a value >= codebook_size before decoding.
        codebook_size = self.audio_encoder.config.codebook_size
        if self.use_4dim_audio_codes:
            frames = codes[:, 0]
            valid_steps = (frames >= codebook_size).sum(dim=(0, 1)) == 0
            frames = frames[:, :, valid_steps]
        else:
            frames = codes[0]
            valid_steps = (frames >= codebook_size).sum(dim=0) == 0
            frames = frames[:, valid_steps]
        codes = frames[None, ...]

    decoded = self.audio_encoder.decode(audio_codes=codes, **self.audio_kwargs).audio_values
    if decoded.ndim != 3:
        # Normalise to three dimensions so the [0, 0] index below is valid.
        decoded = decoded.unsqueeze(0)

    return decoded[0, 0].cpu().float().numpy()
| 108 | |
| 109 | def put(self, value): |
| 110 | batch_size = value.shape[0] // self.decoder.num_codebooks |
no test coverage detected