(name, textcat_config)
| 501 | ) |
| 502 | # fmt: on |
def test_resize_same_results(name, textcat_config):
    """Check that resizing a textcat keeps the predictions for existing labels.

    The pipe is trained for a few updates on the single-label data, the scores
    for a fixed sentence are recorded, and a new label is added. Directly after
    the resize the old labels must score exactly as before; after further
    training the scores are expected to have moved.
    """
    fix_random_seed(0)
    nlp = English()
    textcat = nlp.add_pipe(name, config={"model": textcat_config})

    train_examples = [
        Example.from_dict(nlp.make_doc(text), annotations)
        for text, annotations in TRAIN_DATA_SINGLE_LABEL
    ]
    optimizer = nlp.initialize(get_examples=lambda: train_examples)
    # The output dimension is either 2 or not reported (None).
    assert textcat.model.maybe_get_dim("nO") in [2, None]

    for _ in range(5):
        nlp.update(train_examples, sgd=optimizer, losses={})

    # Baseline: scores of the trained model before any resizing.
    sentence = "I am happy."
    baseline_doc = nlp(sentence)
    assert len(baseline_doc.cats) == 2
    baseline_pos = baseline_doc.cats["POSITIVE"]
    baseline_neg = baseline_doc.cats["NEGATIVE"]

    # Resize by adding a label: scores of the old labels must be unchanged.
    textcat.add_label("NEUTRAL")
    resized_doc = nlp(sentence)
    assert len(resized_doc.cats) == 3
    assert resized_doc.cats["POSITIVE"] == baseline_pos
    assert resized_doc.cats["NEGATIVE"] == baseline_neg
    assert resized_doc.cats["NEUTRAL"] <= 1

    for _ in range(5):
        nlp.update(train_examples, sgd=optimizer, losses={})

    # Further training with the new label in place should shift the scores.
    retrained_doc = nlp(sentence)
    assert len(retrained_doc.cats) == 3
    assert retrained_doc.cats["POSITIVE"] != baseline_pos
    assert retrained_doc.cats["NEGATIVE"] != baseline_neg
    for label in retrained_doc.cats:
        assert retrained_doc.cats[label] <= 1
| 546 | |
| 547 | |
| 548 | def test_error_with_multi_labels(): |
nothing calls this directly
no test coverage detected
searching dependent graphs…