MCPcopy
hub / github.com/QData/TextAttack / USEMetric

Class USEMetric

textattack/metrics/quality_metrics/use.py:14–74  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

12
13
14class USEMetric(Metric):
def __init__(self, **kwargs):
    """Prepare the Universal Sentence Encoder and empty result buffers.

    Args:
        **kwargs: Accepted for interface compatibility; not used here.
    """
    encoder = UniversalSentenceEncoder()
    # Overwrite the encoder's ``model`` attribute with a second
    # UniversalSentenceEncoder instance.
    # NOTE(review): this looks like a deliberate workaround so that a
    # ``self.model.encode(...)`` call inside the encoder's ``_sim_score``
    # dispatches to ``UniversalSentenceEncoder.encode``. Confirm against the
    # SentenceEncoder implementation before removing it.
    encoder.model = UniversalSentenceEncoder()
    self.use_obj = encoder

    # Parallel buffers that ``calculate`` fills with the original and the
    # perturbed ``attacked_text`` of every successful attack result.
    self.original_candidates, self.successful_candidates = [], []

    # Metric name -> value; ``calculate`` stores "avg_attack_use_score" here.
    self.all_metrics = {}
21
22 def calculate(self, results):
23 """Calculates average USE similarity on all successfull attacks.
24
25 Args:
26 results (``AttackResult`` objects):
27 Attack results for each instance in dataset
28
29 Example::
30
31
32 >> import textattack
33 >> import transformers
34 >> model = transformers.AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
35 >> tokenizer = transformers.AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
36 >> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
37 >> attack = textattack.attack_recipes.DeepWordBugGao2018.build(model_wrapper)
38 >> dataset = textattack.datasets.HuggingFaceDataset("glue", "sst2", split="train")
39 >> attack_args = textattack.AttackArgs(
40 num_examples=1,
41 log_to_csv="log.csv",
42 checkpoint_interval=5,
43 checkpoint_dir="checkpoints",
44 disable_stdout=True
45 )
46 >> attacker = textattack.Attacker(attack, dataset, attack_args)
47 >> results = attacker.attack_dataset()
48 >> usem = textattack.metrics.quality_metrics.USEMetric().calculate(results)
49 """
50
51 self.results = results
52
53 for i, result in enumerate(self.results):
54 if isinstance(result, FailedAttackResult):
55 continue
56 elif isinstance(result, SkippedAttackResult):
57 continue
58 else:
59 self.original_candidates.append(result.original_result.attacked_text)
60 self.successful_candidates.append(result.perturbed_result.attacked_text)
61
62 use_scores = []
63 for c in range(len(self.original_candidates)):
64 use_scores.append(
65 self.use_obj._sim_score(
66 self.original_candidates[c], self.successful_candidates[c]
67 ).item()
68 )
69
70 self.all_metrics["avg_attack_use_score"] = round(
71 sum(use_scores) / len(use_scores), 2

Callers 3

calculate (method) · 0.90
augment (method) · 0.90
log_summary (method) · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected