MCPcopy
hub / github.com/MaartenGr/BERTopic / test_delete

Function test_delete

tests/test_reduction/test_delete.py:16–59  ·  view source on GitHub ↗
(model, request)

Source from the content-addressed store, hash-verified

14 ],
15)
def test_delete(model, request):
    """Check that topic bookkeeping stays consistent across two rounds of deletion.

    A deep copy of the fixture named by ``model`` is used so the shared fixture
    is not mutated. Topics ``[1, 2]`` are deleted first, then the two smallest
    remaining non-outlier topics. After each round the test verifies that:

    * the number of distinct topics dropped by the expected amount,
    * the counts reported by ``get_topic_info()`` still sum to the number of
      documents,
    * the cluster model's labels, passed through the topic mapper, equal the
      model's ``topics_``.

    Args:
        model: Name of the topic-model fixture to load.
        request: Pytest ``request`` object used to resolve the fixture by name.
    """
    topic_model = copy.deepcopy(request.getfixturevalue(model))
    nr_topics = len(set(topic_model.topics_))
    length_documents = len(topic_model.topics_)

    # For these two fixtures the distinct-topic count drops by one less than
    # the number of deleted topics.
    # NOTE(review): presumably they start without a -1 outlier topic, which the
    # deletion then introduces; confirm against the fixtures.
    offset = 1 if model in ("online_topic_model", "kmeans_pca_topic_model") else 0

    def check_state(expected_drop):
        """Assert the model is consistent after ``expected_drop`` distinct topics vanished."""
        labels = topic_model.hdbscan_model.labels_
        mappings = topic_model.topic_mapper_.get_mappings(list(labels))
        mapped_labels = [mappings[label] for label in labels]

        assert nr_topics == len(set(topic_model.topics_)) + expected_drop
        # Deleting topics must never lose documents.
        assert topic_model.get_topic_info().Count.sum() == length_documents

        if model == "online_topic_model":
            # NOTE(review): the online fixture's cluster labels apparently only
            # cover the documents from index 950 onward; verify against the fixture.
            assert mapped_labels == topic_model.topics_[950:]
        else:
            assert mapped_labels == topic_model.topics_

    # First deletion
    topics_to_delete = [1, 2]
    topic_model.delete_topics(topics_to_delete)
    check_state(expected_drop=2 - offset)

    # Second deletion: pick the two smallest topics that still exist,
    # excluding the outlier topic (-1).
    remaining_topics = sorted(t for t in set(topic_model.topics_) if t != -1)
    topics_to_delete = remaining_topics[:2]
    topic_model.delete_topics(topics_to_delete)
    check_state(expected_drop=4 - offset)

Callers

nothing calls this directly

Calls 3

delete_topics · Method · 0.80
get_mappings · Method · 0.80
get_topic_info · Method · 0.80

Tested by

no test coverage detected