MCPcopy
hub / github.com/QData/TextAttack / test_use

Function test_use

tests/test_metric_api.py:27–57  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

25
26
def test_use():
    """Attack one SST-2 example with DeepWordBug and check the METEOR metric.

    NOTE(review): despite the ``test_use`` name (presumably inherited from a
    Universal Sentence Encoder metric test), this test exercises
    ``MeteorMetric``; consider renaming once no other test shares the name.

    The attack log CSV and checkpoint directory are written to a temporary
    directory so the test does not leave files in the working directory.
    """
    import math
    import os
    import tempfile

    import transformers

    from textattack import AttackArgs, Attacker
    from textattack.attack_recipes import DeepWordBugGao2018
    from textattack.datasets import HuggingFaceDataset
    from textattack.metrics.quality_metrics import MeteorMetric
    from textattack.models.wrappers import HuggingFaceModelWrapper

    model_name = "distilbert-base-uncased-finetuned-sst-2-english"
    model = transformers.AutoModelForSequenceClassification.from_pretrained(model_name)
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    model_wrapper = HuggingFaceModelWrapper(model, tokenizer)
    attack = DeepWordBugGao2018.build(model_wrapper)
    dataset = HuggingFaceDataset("glue", "sst2", split="train")

    # Keep the CSV log and checkpoint output out of the working directory;
    # the attack results themselves are held in memory and outlive the dir.
    with tempfile.TemporaryDirectory() as out_dir:
        attack_args = AttackArgs(
            num_examples=1,
            log_to_csv=os.path.join(out_dir, "log.csv"),
            checkpoint_interval=5,
            checkpoint_dir=os.path.join(out_dir, "checkpoints"),
            disable_stdout=True,
        )
        attacker = Attacker(attack, dataset, attack_args)
        results = attacker.attack_dataset()

    meteor_metrics = MeteorMetric().calculate(results)
    score = meteor_metrics["avg_attack_meteor_score"]

    # The expected score is quoted to two decimal places, so compare within
    # half of 0.01 rather than relying on exact float equality.
    assert math.isclose(score, 0.71, abs_tol=0.005), score
58
59
60def test_metric_recipe():

Callers

nothing calls this directly

Calls 9

attack_dataset · Method · 0.95
HuggingFaceDataset · Class · 0.90
AttackArgs · Class · 0.90
Attacker · Class · 0.90
MeteorMetric · Class · 0.90
from_pretrained · Method · 0.45
build · Method · 0.45
calculate · Method · 0.45

Tested by

no test coverage detected