()
| 34 | |
| 35 | |
def test_word_swap_change_location_consistent():
    """Check that ``WordSwapChangeLocation(consistent=True)`` replaces every
    mention of a location with the same new location.

    The input sentence mentions "New York" twice. The first augmented output
    is tagged with flair's NER model alongside the original, and the test
    asserts that:

    * the original sentence actually yields entities (so the tag comparison
      is not vacuous),
    * the augmented sentence carries the same sequence of entity tags,
    * "New York" no longer appears in the augmented sentence, and
    * every tagged entity in the augmented sentence has the same surface
      text, i.e. both mentions were swapped for one and the same location,
      which is the property ``consistent=True`` exists to guarantee.
    """
    from flair.data import Sentence
    from flair.models import SequenceTagger

    from textattack.augmentation import Augmenter
    from textattack.transformations.word_swaps import WordSwapChangeLocation

    augmenter = Augmenter(transformation=WordSwapChangeLocation(consistent=True))
    s = "I am in New York. I love living in New York."
    s_augmented = augmenter.augment(s)
    # Fail with a readable message instead of an IndexError on s_augmented[0].
    assert s_augmented, "augmenter returned no augmented sentences"

    tagger = SequenceTagger.load("flair/ner-english")
    original_text = Sentence(s)
    augmented_text = Sentence(s_augmented[0])
    tagger.predict(original_text)
    tagger.predict(augmented_text)

    original_spans = original_text.get_spans("ner")
    augmented_spans = augmented_text.get_spans("ner")
    entity_original = [span.tag for span in original_spans]
    entity_augmented = [span.tag for span in augmented_spans]

    # Guard against a vacuous pass: if the tagger found nothing in the
    # original sentence, the equality check below would trivially succeed.
    assert entity_original, "NER tagger found no entities in the original sentence"
    assert entity_original == entity_augmented
    assert s_augmented[0].count("New York") == 0
    # consistent=True: both "New York" mentions must map to the same location,
    # so all augmented entity spans share a single surface text.
    assert len({span.text for span in augmented_spans}) == 1
| 62 | |
| 63 | |
| 64 | def test_word_swap_change_name(): |
nothing calls this directly
no test coverage detected