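After having fit a model, use `transform` to predict new instances.

Arguments:
documents: A single document or a list of documents to predict on.
embeddings: Pre-trained document embeddings. These can be used instead of the sentence-transformer model.
images: A list of paths to the images to predict on, or the images themselves.

Returns:
predictions: Topic predictions for each document.
probabilities: The topic probability distribution, returned by default. If `calculate_probabilities` in BERTopic is set to False, the probabilities are not calculated, which speeds up computation and decreases memory usage.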
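The argument handling described here can be sketched in plain NumPy. `prepare_transform_inputs` is a hypothetical name, not part of BERTopic, and the sketch assumes embeddings arrive as a 2-D `np.ndarray` with one row per document. One difference is deliberate: in the `transform` source, `check_embeddings_shape` runs before a single string is wrapped in a list, whereas this sketch wraps first, so one string pairs with a `(1, dim)` array.

```python
from typing import List, Optional, Union

import numpy as np


def prepare_transform_inputs(
    documents: Union[str, List[str], None],
    embeddings: Optional[np.ndarray] = None,
) -> List[Optional[str]]:
    """Hypothetical helper: normalize documents, then validate embeddings."""
    # Wrap a single document (or None, as the transform source also does) in a list.
    if isinstance(documents, str) or documents is None:
        documents = [documents]

    if embeddings is not None:
        if not isinstance(embeddings, np.ndarray):
            raise TypeError("embeddings must be a numpy array")
        # Expect exactly one embedding row per document: shape (n_docs, dim).
        if embeddings.ndim != 2 or embeddings.shape[0] != len(documents):
            raise ValueError(
                f"expected embeddings of shape ({len(documents)}, dim), got {embeddings.shape}"
            )
    return documents


docs = prepare_transform_inputs("a single document", np.zeros((1, 384)))
print(docs)  # ['a single document']
```

Normalizing before validating keeps the row-count check meaningful for both single documents and batches.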
transform(
    self,
    documents: Union[str, List[str]],
    embeddings: np.ndarray = None,
    images: List[str] | None = None,
) -> Tuple[List[int], np.ndarray]
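The fit-then-transform contract, including the return type `Tuple[List[int], np.ndarray]` declared in the source, can be mimicked with a toy model. `ToyTopicModel` is hypothetical: it assigns topics by nearest centroid in embedding space and turns distances into a per-document distribution with a softmax, assuming one probability row per document. BERTopic itself does not work this way (by default it reduces embeddings with UMAP and clusters them with HDBSCAN), so only the call pattern carries over, including the refusal to predict before fitting, which plays the role of `check_is_fitted`.

```python
from typing import List, Optional, Tuple

import numpy as np


class ToyTopicModel:
    """Hypothetical stand-in that only mimics the fit/transform call pattern.

    Topics are assigned by nearest centroid; this is NOT BERTopic's algorithm.
    """

    def __init__(self) -> None:
        self.centroids_: Optional[np.ndarray] = None

    def fit(self, embeddings: np.ndarray, labels: List[int]) -> "ToyTopicModel":
        labels_arr = np.asarray(labels)
        # One centroid per topic id 0..k-1 (assumes contiguous, non-negative labels).
        k = int(labels_arr.max()) + 1
        self.centroids_ = np.stack(
            [embeddings[labels_arr == t].mean(axis=0) for t in range(k)]
        )
        return self

    def transform(self, embeddings: np.ndarray) -> Tuple[List[int], np.ndarray]:
        # Plays the role of check_is_fitted: refuse to predict before fit.
        if self.centroids_ is None:
            raise RuntimeError("call fit before transform")
        # Squared Euclidean distance from each row to each centroid: (n_docs, n_topics).
        diffs = embeddings[:, None, :] - self.centroids_[None, :, :]
        dists = (diffs ** 2).sum(axis=2)
        # Numerically stable softmax over negative distances: one distribution per row.
        logits = -dists - (-dists).max(axis=1, keepdims=True)
        probs = np.exp(logits)
        probs /= probs.sum(axis=1, keepdims=True)
        predictions = dists.argmin(axis=1).tolist()
        return predictions, probs


X = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0]])
model = ToyTopicModel().fit(X, [0, 0, 1, 1])
topics, probs = model.transform(np.array([[0.05, 0.0], [4.9, 5.2]]))
print(topics)  # [0, 1]
```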
        return predictions, self.probabilities_

    def transform(
        self,
        documents: Union[str, List[str]],
        embeddings: np.ndarray = None,
        images: List[str] | None = None,
    ) -> Tuple[List[int], np.ndarray]:
        """After having fit a model, use transform to predict new instances.

        Arguments:
            documents: A single document or a list of documents to predict on
            embeddings: Pre-trained document embeddings. These can be used
                        instead of the sentence-transformer model.
            images: A list of paths to the images to predict on or the images themselves

        Returns:
            predictions: Topic predictions for each document
            probabilities: The topic probability distribution which is returned by default.
                           If `calculate_probabilities` in BERTopic is set to False, then the
                           probabilities are not calculated to speed up computation and
                           decrease memory usage.

        Examples:
        ```python
        from bertopic import BERTopic
        from sklearn.datasets import fetch_20newsgroups

        docs = fetch_20newsgroups(subset='all')['data']
        topic_model = BERTopic().fit(docs)
        topics, probs = topic_model.transform(docs)
        ```

        If you want to use your own embeddings:

        ```python
        from bertopic import BERTopic
        from sklearn.datasets import fetch_20newsgroups
        from sentence_transformers import SentenceTransformer

        # Create embeddings
        docs = fetch_20newsgroups(subset='all')['data']
        sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
        embeddings = sentence_model.encode(docs, show_progress_bar=True)

        # Create topic model
        topic_model = BERTopic().fit(docs, embeddings)
        topics, probs = topic_model.transform(docs, embeddings)
        ```
        """
        check_is_fitted(self)
        check_embeddings_shape(embeddings, documents)

        if isinstance(documents, str) or documents is None:
            documents = [documents]

        if embeddings is None:
            embeddings = self._extract_embeddings(documents, images=images, method="document", verbose=self.verbose)

        # Check if an embedding model was found
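Once `transform` returns, the probability array can be post-processed with plain NumPy. The helper below is hypothetical (`top_k_topics` is not a BERTopic function). It assumes `probs` is the full 2-D `(n_docs, n_topics)` distribution that the docstring describes as returned by default, and that column `j` stands for topic `j`. Because the docstring says probabilities are not calculated when `calculate_probabilities` is False, the helper rejects anything that is not 2-D rather than guessing its layout.

```python
from typing import List, Optional

import numpy as np


def top_k_topics(probs: Optional[np.ndarray], k: int = 3) -> List[List[int]]:
    """Hypothetical helper: the k most probable column indices per document.

    Assumes `probs` has shape (n_docs, n_topics) and that column j stands
    for topic j; rejects anything else instead of guessing.
    """
    if probs is None or probs.ndim != 2:
        raise ValueError("expected a 2-D (n_docs, n_topics) probability array")
    k = min(k, probs.shape[1])
    # Sort each row ascending, reverse to descending, keep the first k columns.
    order = np.argsort(probs, axis=1)[:, ::-1][:, :k]
    return order.tolist()


probs = np.array([[0.1, 0.7, 0.2], [0.5, 0.2, 0.3]])
print(top_k_topics(probs, k=2))  # [[1, 2], [0, 2]]
```

Clamping `k` to the number of columns lets callers ask for "up to k" topics without checking how many topics the model found.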