MCPcopy
hub / github.com/MaartenGr/BERTopic / test_full_model

Function test_full_model

tests/test_bertopic.py:32–155  ·  view source on GitHub ↗

Tests the entire pipeline in one go. This serves as a sanity check to see if the default settings result in a good separation of topics. NOTE: This does not cover all cases; it merely combines the individual checks into a single end-to-end test.

(model, documents, request)

Source from the content-addressed store, hash-verified

30 ],
31)
32def test_full_model(model, documents, request):
33 """Tests the entire pipeline in one go. This serves as a sanity check to see if the default
34 settings result in a good separation of topics.
35
36 NOTE: This does not cover all cases but merely combines it all together
37 """
38 topic_model = copy.deepcopy(request.getfixturevalue(model))
39 if model == "base_topic_model":
40 topic_model.save(
41 "model_dir",
42 serialization="pytorch",
43 save_ctfidf=True,
44 save_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
45 )
46 topic_model = BERTopic.load("model_dir")
47
48 if model == "cuml_base_topic_model":
49 assert "cuml" in str(type(topic_model.umap_model)).lower()
50 assert "cuml" in str(type(topic_model.hdbscan_model)).lower()
51
52 topics = topic_model.topics_
53
54 for topic in set(topics):
55 words = topic_model.get_topic(topic)[:10]
56 assert len(words) == 10
57
58 for topic in topic_model.get_topic_freq().Topic:
59 words = topic_model.get_topic(topic)[:10]
60 assert len(words) == 10
61
62 assert len(topic_model.get_topic_freq()) > 2
63 assert len(topic_model.get_topics()) == len(topic_model.get_topic_freq())
64
65 # Test extraction of document info
66 document_info = topic_model.get_document_info(documents)
67 assert len(document_info) == len(documents)
68
69 # Test transform
70 doc = "This is a new document to predict."
71 topics_test, _probs_test = topic_model.transform([doc, doc])
72
73 assert len(topics_test) == 2
74
75 # Test zero-shot topic modeling
76 if topic_model._is_zeroshot():
77 if topic_model._outliers:
78 assert set(topic_model.topic_labels_.keys()) == set(range(-1, len(topic_model.topic_labels_) - 1))
79 else:
80 assert set(topic_model.topic_labels_.keys()) == set(range(len(topic_model.topic_labels_)))
81
82 # Test topics over time
83 timestamps = [i % 10 for i in range(len(documents))]
84 topics_over_time = topic_model.topics_over_time(documents, timestamps)
85
86 assert topics_over_time.Frequency.sum() == len(documents)
87 assert len(topics_over_time.Topic.unique()) == len(set(topics))
88
89 # Test hierarchical topics

Callers

nothing calls this directly

Calls 15 (12 shown)

save — Method · 0.80
load — Method · 0.80
get_topic — Method · 0.80
get_topic_freq — Method · 0.80
get_topics — Method · 0.80
get_document_info — Method · 0.80
_is_zeroshot — Method · 0.80
topics_over_time — Method · 0.80
hierarchical_topics — Method · 0.80
get_topic_tree — Method · 0.80
find_topics — Method · 0.80
reduce_topics — Method · 0.80

Tested by

no test coverage detected