()
| 8 | |
| 9 | |
def test_word_swap_change_location():
    """Check that ``WordSwapChangeLocation`` keeps NER entity tags intact.

    The sentence "I am in Dallas." is augmented with an ``Augmenter`` built
    around ``WordSwapChangeLocation``. Both the original sentence and the
    first augmented candidate are tagged with the ``flair/ner-english``
    sequence tagger. The test asserts that the two sentences yield the same
    list of entity tags, in the same order.
    """
    from flair.data import Sentence
    from flair.models import SequenceTagger

    from textattack.augmentation import Augmenter
    from textattack.transformations.word_swaps import WordSwapChangeLocation

    text = "I am in Dallas."
    location_augmenter = Augmenter(transformation=WordSwapChangeLocation())
    candidates = location_augmenter.augment(text)
    # Only the first augmented candidate is compared against the original.
    swapped_sentence = Sentence(candidates[0])

    ner_tagger = SequenceTagger.load("flair/ner-english")
    source_sentence = Sentence(text)
    ner_tagger.predict(source_sentence)
    ner_tagger.predict(swapped_sentence)

    source_tags = [span.tag for span in source_sentence.get_spans("ner")]
    swapped_tags = [span.tag for span in swapped_sentence.get_spans("ner")]
    assert source_tags == swapped_tags
| 34 | |
| 35 | |
| 36 | def test_word_swap_change_location_consistent(): |
nothing calls this directly
no test coverage detected