Tests the entire pipeline in one go. This serves as a sanity check to see whether the default settings result in a good separation of topics. NOTE: This does not cover all cases; it merely combines all the steps into a single run.
(model, documents, request)
| 30 | ], |
| 31 | ) |
| 32 | def test_full_model(model, documents, request): |
| 33 | """Tests the entire pipeline in one go. This serves as a sanity check to see if the default |
| 34 | settings result in a good separation of topics. |
| 35 | |
| 36 | NOTE: This does not cover all cases but merely combines it all together |
| 37 | """ |
| 38 | topic_model = copy.deepcopy(request.getfixturevalue(model)) |
| 39 | if model == "base_topic_model": |
| 40 | topic_model.save( |
| 41 | "model_dir", |
| 42 | serialization="pytorch", |
| 43 | save_ctfidf=True, |
| 44 | save_embedding_model="sentence-transformers/all-MiniLM-L6-v2", |
| 45 | ) |
| 46 | topic_model = BERTopic.load("model_dir") |
| 47 | |
| 48 | if model == "cuml_base_topic_model": |
| 49 | assert "cuml" in str(type(topic_model.umap_model)).lower() |
| 50 | assert "cuml" in str(type(topic_model.hdbscan_model)).lower() |
| 51 | |
| 52 | topics = topic_model.topics_ |
| 53 | |
| 54 | for topic in set(topics): |
| 55 | words = topic_model.get_topic(topic)[:10] |
| 56 | assert len(words) == 10 |
| 57 | |
| 58 | for topic in topic_model.get_topic_freq().Topic: |
| 59 | words = topic_model.get_topic(topic)[:10] |
| 60 | assert len(words) == 10 |
| 61 | |
| 62 | assert len(topic_model.get_topic_freq()) > 2 |
| 63 | assert len(topic_model.get_topics()) == len(topic_model.get_topic_freq()) |
| 64 | |
| 65 | # Test extraction of document info |
| 66 | document_info = topic_model.get_document_info(documents) |
| 67 | assert len(document_info) == len(documents) |
| 68 | |
| 69 | # Test transform |
| 70 | doc = "This is a new document to predict." |
| 71 | topics_test, _probs_test = topic_model.transform([doc, doc]) |
| 72 | |
| 73 | assert len(topics_test) == 2 |
| 74 | |
| 75 | # Test zero-shot topic modeling |
| 76 | if topic_model._is_zeroshot(): |
| 77 | if topic_model._outliers: |
| 78 | assert set(topic_model.topic_labels_.keys()) == set(range(-1, len(topic_model.topic_labels_) - 1)) |
| 79 | else: |
| 80 | assert set(topic_model.topic_labels_.keys()) == set(range(len(topic_model.topic_labels_))) |
| 81 | |
| 82 | # Test topics over time |
| 83 | timestamps = [i % 10 for i in range(len(documents))] |
| 84 | topics_over_time = topic_model.topics_over_time(documents, timestamps) |
| 85 | |
| 86 | assert topics_over_time.Frequency.sum() == len(documents) |
| 87 | assert len(topics_over_time.Topic.unique()) == len(set(topics)) |
| 88 | |
| 89 | # Test hierarchical topics |
nothing calls this directly
no test coverage detected