(model, documents, request)
| 14 | ], |
| 15 | ) |
def test_merge(model, documents, request):
    """Merging topics 1 and 2 twice shrinks the topic set by one per merge.

    The fixture named by ``model`` is deep-copied so the shared fixture object
    is not mutated by ``merge_topics``. After each merge the test checks that:

    * the number of distinct topics dropped by the number of merges done so far,
    * ``get_topic_info().Count`` still accounts for every document,
    * the clustering labels, mapped through ``topic_mapper_``, agree with
      ``topics_``.

    The second call passes IDs ``[1, 2]`` again, so it merges whichever
    topics carry those IDs after the first merge.
    """
    topic_model = copy.deepcopy(request.getfixturevalue(model))
    nr_topics = len(set(topic_model.topics_))

    for merges_done in (1, 2):
        # Build a fresh list on every call, matching the original test, in
        # case merge_topics mutates its argument.
        topic_model.merge_topics(documents, [1, 2])

        labels = list(topic_model.hdbscan_model.labels_)
        mappings = topic_model.topic_mapper_.get_mappings(labels)
        mapped_labels = [mappings[label] for label in labels]

        assert nr_topics == len(set(topic_model.topics_)) + merges_done
        assert topic_model.get_topic_info().Count.sum() == len(documents)
        if model == "online_topic_model":
            # NOTE(review): presumably the online model's clustering labels
            # cover only its last training batch, i.e. documents[950:];
            # confirm against the online_topic_model fixture.
            assert mapped_labels == topic_model.topics_[950:]
        else:
            assert mapped_labels == topic_model.topics_
nothing calls this directly
no test coverage detected