(self, texts)
| 77 | self.model = T5EncoderModel.from_pretrained(path, **t5_model_kwargs).eval() |
| 78 | |
def get_text_embeddings(self, texts):
    """Encode ``texts`` with ``self.model`` and return its last hidden states.

    Each text is first passed through ``self.text_preprocessing``, then
    tokenized with padding/truncation to exactly 77 tokens, and finally run
    through ``self.model`` on ``self.device`` with gradients disabled.

    Args:
        texts: An iterable of strings. A single ``str`` is treated as a
            one-element batch instead of being iterated character by
            character.

    Returns:
        The ``'last_hidden_state'`` entry of the model output, detached from
        autograd. Presumably shaped ``(len(texts), 77, hidden_dim)`` --
        TODO confirm against the encoder in use.
    """
    if isinstance(texts, str):
        # Iterating a bare string would encode every character as its own
        # prompt; wrap it so it is handled as a batch of one.
        texts = [texts]
    texts = [self.text_preprocessing(text) for text in texts]

    text_tokens_and_mask = self.tokenizer(
        texts,
        max_length=77,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        add_special_tokens=True,
        return_tensors='pt',
    )
    # The tokenizer output is only read here, never written back to; the
    # tensors are moved to the model's device once, outside the no_grad scope.
    input_ids = text_tokens_and_mask['input_ids'].to(self.device)
    attention_mask = text_tokens_and_mask['attention_mask'].to(self.device)

    with torch.no_grad():
        text_encoder_embs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )['last_hidden_state'].detach()

    return text_encoder_embs
| 101 | |
| 102 | def text_preprocessing(self, text): |
| 103 | if self.use_text_preprocessing: |
no test coverage detected