(self, modality, audio=None, text=None, use_text_ratio=0.5, device=None)
| 91 | |
| 92 | |
def get_query_embed(self, modality, audio=None, text=None, use_text_ratio=0.5, device=None):
    """Return the query embedding for the requested modality, converted with ``.float()``.

    Args:
        modality: ``'audio'``, ``'text'``, or ``'hybrid'``. The historical
            misspelling ``'hybird'`` is still accepted for backward compatibility.
        audio: Input forwarded to ``self._get_audio_embed`` when the audio
            branch is used.
        text: Input forwarded to ``self._get_text_embed`` when the text
            branch is used.
        use_text_ratio: In hybrid mode, the probability of using the text
            branch. The audio branch is used with probability
            ``1 - use_text_ratio``.
        device: Unused by this method; kept for interface compatibility.

    Returns:
        The selected embedding after calling ``.float()`` on it.

    Raises:
        NotImplementedError: If ``modality`` is not a recognized value.
    """
    if modality in ('hybrid', 'hybird'):
        # random.random() lies in [0, 1), so the text branch is picked with
        # probability use_text_ratio and the audio branch otherwise.
        modality = 'audio' if random.random() > use_text_ratio else 'text'

    if modality == 'audio':
        embed = self._get_audio_embed(audio)
    elif modality == 'text':
        embed = self._get_text_embed(text)
    else:
        raise NotImplementedError(
            f"Please check flag 'training_modality'. Got modality={modality!r}."
        )

    return embed.float()
| 107 | |
| 108 | def tokenizer(self, text): |
| 109 | result = self.tokenize( |
no test coverage detected