hub / github.com/explosion/spaCy / test_resize_same_results

Function test_resize_same_results

spacy/tests/pipeline/test_textcat.py:503–545
(name, textcat_config)

Source from the content-addressed store, hash-verified

def test_resize_same_results(name, textcat_config):
    # Ensure that the resized textcat classifiers still produce the same results for old labels
    fix_random_seed(0)
    nlp = English()
    pipe_config = {"model": textcat_config}
    textcat = nlp.add_pipe(name, config=pipe_config)

    train_examples = []
    for text, annotations in TRAIN_DATA_SINGLE_LABEL:
        train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
    optimizer = nlp.initialize(get_examples=lambda: train_examples)
    assert textcat.model.maybe_get_dim("nO") in [2, None]

    for i in range(5):
        losses = {}
        nlp.update(train_examples, sgd=optimizer, losses=losses)

    # test the trained model before resizing
    test_text = "I am happy."
    doc = nlp(test_text)
    assert len(doc.cats) == 2
    pos_pred = doc.cats["POSITIVE"]
    neg_pred = doc.cats["NEGATIVE"]

    # test the trained model again after resizing
    textcat.add_label("NEUTRAL")
    doc = nlp(test_text)
    assert len(doc.cats) == 3
    assert doc.cats["POSITIVE"] == pos_pred
    assert doc.cats["NEGATIVE"] == neg_pred
    assert doc.cats["NEUTRAL"] <= 1

    for i in range(5):
        losses = {}
        nlp.update(train_examples, sgd=optimizer, losses=losses)

    # test the trained model again after training further with new label
    doc = nlp(test_text)
    assert len(doc.cats) == 3
    assert doc.cats["POSITIVE"] != pos_pred
    assert doc.cats["NEGATIVE"] != neg_pred
    for cat in doc.cats:
        assert doc.cats[cat] <= 1
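
The invariant this test checks (scores for existing labels are unchanged immediately after a new label is added) can be illustrated with a small standalone sketch. This is not spaCy's implementation, and the excerpt does not show how spaCy's resizable layers work internally. `TinyClassifier`, its weight values, and its zero-initialised new row are all hypothetical, and the sketch assumes one independent sigmoid per label. Under a softmax over all labels, adding a label would change the normalisation, so this simple argument would not carry over.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


class TinyClassifier:
    """Hypothetical linear scorer with one independent sigmoid per label."""

    def __init__(self, labels, n_features):
        self.labels = list(labels)
        self.n_features = n_features
        # One weight row and one bias per label (arbitrary fixed values).
        self.weights = [
            [0.1 * (i + 1) * (j + 1) for j in range(n_features)]
            for i in range(len(self.labels))
        ]
        self.biases = [0.05 * i for i in range(len(self.labels))]

    def add_label(self, label):
        # Grow the output layer: existing rows stay untouched,
        # and the new label gets an all-zero row and a zero bias.
        self.labels.append(label)
        self.weights.append([0.0] * self.n_features)
        self.biases.append(0.0)

    def predict(self, features):
        scores = {}
        for label, row, bias in zip(self.labels, self.weights, self.biases):
            logit = sum(w * x for w, x in zip(row, features)) + bias
            scores[label] = sigmoid(logit)
        return scores


features = [1.0, -2.0, 0.5]
model = TinyClassifier(["POSITIVE", "NEGATIVE"], n_features=3)
before = model.predict(features)

model.add_label("NEUTRAL")
after = model.predict(features)

# Old labels keep exactly the same scores after the resize.
assert after["POSITIVE"] == before["POSITIVE"]
assert after["NEGATIVE"] == before["NEGATIVE"]
# The new label starts from a zero logit, so its score is sigmoid(0).
assert after["NEUTRAL"] == 0.5
```

The exact equality comparisons mirror the ones in the test above: because the old rows are never touched, the same floating-point operations run in the same order before and after the resize.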

Callers

nothing calls this directly

Calls 9

English · Class · 0.90
add_pipe · Method · 0.80
append · Method · 0.80
from_dict · Method · 0.80
make_doc · Method · 0.80
nlp · Function · 0.70
initialize · Method · 0.45
update · Method · 0.45
add_label · Method · 0.45

Tested by

no test coverage detected
