(model, documents, request)
| 16 | ], |
| 17 | ) |
def test_update_topics(model, documents, request):
    """Check that ``update_topics`` refreshes representations and honours custom topics.

    Step 1: re-running ``update_topics`` with ``n_gram_range=(1, 3)`` must produce a
    c-TF-IDF matrix with more columns while leaving the topic assignments unchanged.

    Step 2: topic 1, and afterwards topic 2, are each folded into topic 0 by passing
    an explicit ``topics`` list. Each such merge must reduce the number of distinct
    topics by exactly one.
    """
    # Work on a deep copy so the fixture model obtained via ``request`` is not modified.
    fitted = copy.deepcopy(request.getfixturevalue(model))
    ctfidf_before = fitted.c_tf_idf_
    baseline_topics = fitted.topics_

    fitted.update_topics(documents, n_gram_range=(1, 3))

    # The (1, 3) n-gram range is expected to yield more c-TF-IDF columns than the
    # fixture's original representation (presumably fitted with a narrower range).
    assert ctfidf_before.shape[1] < fitted.c_tf_idf_.shape[1]
    assert baseline_topics == fitted.topics_

    previous = baseline_topics
    for absorbed in (1, 2):
        # Relabel every document of topic ``absorbed`` as topic 0, i.e. merge it into 0.
        merged = [0 if label == absorbed else label for label in previous]
        fitted.update_topics(documents, topics=merged, n_gram_range=(1, 3))
        assert len(set(previous)) - 1 == len(set(fitted.topics_))
        previous = fitted.topics_
| 38 | |
| 39 | |
| 40 | @pytest.mark.parametrize( |
nothing calls this directly
no test coverage detected