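Get the L2-normalized text embedding from the model Parameters ---------- data: dict of torch.Tensor text inputs passed to encode_text, keyed by name; each value is moved to the model's device (the dict is modified in place) Returns ------- text_embed: torch.Tensor a tensor of L2-normalized text_embeds (N, D)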
(self, data)
| 730 | return self.logit_scale_a.exp(), self.logit_scale_t.exp() |
| 731 | |
| 732 | def get_text_embedding(self, data): |
| 733 | """Get the L2-normalized text embedding from the model |
| 734 | |
| 735 | Parameters |
| 736 | ---------- |
| 737 | data: dict of torch.Tensor |
| 738 | text inputs passed to encode_text, keyed by name (presumably tokenizer outputs - confirm against callers); each value is moved to the model's device and written back into this same dict, so the caller's dict is modified in place |
| 739 | |
| 740 | Returns |
| 741 | ------- |
| 742 | text_embed: torch.Tensor |
| 743 | a tensor of text_embeds (N, D), L2-normalized along the last dimension |
| 744 | |
| 745 | """ |
| 746 | device = next(self.parameters()).device  # NOTE(review): uses the first parameter's device; assumes the whole model sits on one device |
| 747 | for k in data:  # overwrites the caller's entries with on-device copies |
| 748 | data[k] = data[k].to(device) |
| 749 | text_embeds = self.encode_text(data, device=device) |
| 750 | text_embeds = F.normalize(text_embeds, dim=-1)  # default p=2, so each row has unit L2 norm |
| 751 | |
| 752 | return text_embeds |
| 753 | |
| 754 | def get_audio_embedding(self, data): |
| 755 | """Get the audio embedding from the model |
no test coverage detected