(model, request)
| 14 | ], |
| 15 | ) |
def test_delete(model, request):
    """Delete topics twice and check the model's invariants after each deletion.

    The fixture named by ``model`` is deep-copied so the deletions do not
    mutate the shared fixture. After each ``delete_topics`` call the test
    checks that:

    * the number of distinct topics in ``topics_`` dropped by the expected amount,
    * ``get_topic_info().Count`` still sums to the original number of documents,
    * the clustering model's raw labels, passed through ``topic_mapper_``,
      agree with ``topics_``.
    """
    topic_model = copy.deepcopy(request.getfixturevalue(model))
    nr_topics = len(set(topic_model.topics_))
    length_documents = len(topic_model.topics_)

    # NOTE(review): presumably these fixtures start without an outlier topic (-1),
    # so the first deletion introduces -1 and the distinct-topic count drops by
    # one less than the number of deleted topics. Confirm against the fixtures.
    no_initial_outliers = model in ("online_topic_model", "kmeans_pca_topic_model")

    # First deletion
    topics_to_delete = [1, 2]
    topic_model.delete_topics(topics_to_delete)
    _assert_delete_invariants(
        model,
        topic_model,
        nr_topics,
        length_documents,
        expected_drop=1 if no_initial_outliers else 2,
    )

    # Second deletion: pick the two lowest topic ids that still exist,
    # excluding the outlier topic -1.
    remaining_topics = sorted(set(topic_model.topics_) - {-1})
    topics_to_delete = remaining_topics[:2]
    # Fail early with a clear message instead of a confusing count mismatch.
    assert len(topics_to_delete) == 2, (
        f"need two non-outlier topics for the second deletion, found {remaining_topics}"
    )
    topic_model.delete_topics(topics_to_delete)
    _assert_delete_invariants(
        model,
        topic_model,
        nr_topics,
        length_documents,
        expected_drop=3 if no_initial_outliers else 4,
    )


def _assert_delete_invariants(model, topic_model, nr_topics, length_documents, expected_drop):
    """Assert the invariants that must hold after a ``delete_topics`` call.

    Args:
        model: Fixture name of the topic model under test.
        topic_model: The (deep-copied) topic model after deletion.
        nr_topics: Distinct topic count before any deletion.
        length_documents: Number of documents (length of ``topics_``) before deletion.
        expected_drop: How many fewer distinct topics ``topics_`` should now contain.
    """
    assert nr_topics == len(set(topic_model.topics_)) + expected_drop
    assert topic_model.get_topic_info().Count.sum() == length_documents

    labels = topic_model.hdbscan_model.labels_
    mappings = topic_model.topic_mapper_.get_mappings(list(labels))
    mapped_labels = [mappings[label] for label in labels]

    if model == "online_topic_model":
        # NOTE(review): the online fixture's clustering labels presumably cover only
        # its final training batch, i.e. the documents from index 950 onward. Confirm
        # against the fixture's partial_fit setup.
        assert mapped_labels == topic_model.topics_[950:]
    else:
        assert mapped_labels == topic_model.topics_
nothing calls this directly
no test coverage detected