(self, batch)
| 76 | return embed.detach() |
| 77 | |
def _get_text_embed(self, batch):
    """Return text embeddings for ``batch`` computed without gradient tracking.

    ``batch`` is handed to ``self.tokenizer`` and the tokenized result to
    ``self.model.get_text_embedding``.  A batch containing exactly one item
    is repeated so that two items are passed through; afterwards only the
    first row of the result is kept and its leading dimension is restored
    with ``unsqueeze(0)``.

    Returns:
        The detached embedding produced by ``self.model.get_text_embedding``.
    """
    is_singleton = len(batch) == 1
    # Repeat a lone item so the tokenizer/model receive two entries.
    texts = batch * 2 if is_singleton else batch

    with torch.no_grad():
        # NOTE(review): a previous comment here mentioned switching the
        # 'fusion' truncate mode to 'rand_trunc' in unfusion mode; no
        # truncate mode is set in this method -- confirm where it applies.
        tokens = self.tokenizer(texts)
        text_embed = self.model.get_text_embedding(tokens)

    if is_singleton:
        # Drop the duplicated row, keeping a leading dimension of size 1.
        text_embed = text_embed[0].unsqueeze(0)

    return text_embed.detach()
| 91 | |
| 92 | |
| 93 | def get_query_embed(self, modality, audio=None, text=None, use_text_ratio=0.5, device=None): |
no test coverage detected