MCPcopy
hub / github.com/QData/TextAttack / Perplexity

Class Perplexity

textattack/metrics/quality_metrics/perplexity.py:16–119  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

14
15
16class Perplexity(Metric):
def __init__(self, model_name="gpt2"):
    """Set up the model and tokenizer stored on this metric instance.

    Args:
        model_name (str): Selects which pretrained weights to load.
            The exact value ``"gpt2"`` loads ``GPT2LMHeadModel`` and
            ``GPT2Tokenizer``; any other value is forwarded unchanged to
            ``AutoModelForMaskedLM.from_pretrained`` and
            ``AutoTokenizer.from_pretrained``.

    In both cases the model is moved to ``textattack.shared.utils.device``
    and switched to eval mode. ``max_length`` is read from the model
    config (``n_positions`` for GPT-2, ``max_position_embeddings``
    otherwise) and ``stride`` is fixed at 512.
    """
    # Bookkeeping containers start out empty for every new instance.
    self.all_metrics = dict()
    self.original_candidates = list()
    self.successful_candidates = list()

    if model_name != "gpt2":
        # Imported lazily so transformers is only touched when needed.
        from transformers import AutoModelForMaskedLM, AutoTokenizer

        lm = self.ppl_model = AutoModelForMaskedLM.from_pretrained(model_name)
        self.ppl_tokenizer = AutoTokenizer.from_pretrained(model_name)
        lm.to(textattack.shared.utils.device)
        lm.eval()
        self.max_length = lm.config.max_position_embeddings
    else:
        from transformers import GPT2LMHeadModel, GPT2Tokenizer

        lm = self.ppl_model = GPT2LMHeadModel.from_pretrained("gpt2")
        lm.to(textattack.shared.utils.device)
        self.ppl_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        lm.eval()
        self.max_length = lm.config.n_positions

    self.stride = 512
40
41 def calculate(self, results):
42 """Calculates average Perplexity on all successful attacks using a
43 pre-trained small GPT-2 model.
44
45 Args:
46 results (``AttackResult`` objects):
47 Attack results for each instance in dataset
48
49 Example::
50
51
52 >> import textattack
53 >> import transformers
54 >> model = transformers.AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
55 >> tokenizer = transformers.AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
56 >> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
57 >> attack = textattack.attack_recipes.DeepWordBugGao2018.build(model_wrapper)
58 >> dataset = textattack.datasets.HuggingFaceDataset("glue", "sst2", split="train")
59 >> attack_args = textattack.AttackArgs(
60 num_examples=1,
61 log_to_csv="log.csv",
62 checkpoint_interval=5,
63 checkpoint_dir="checkpoints",
64 disable_stdout=True
65 )
66 >> attacker = textattack.Attacker(attack, dataset, attack_args)
67 >> results = attacker.attack_dataset()
68 >> ppl = textattack.metrics.quality_metrics.Perplexity().calculate(results)
69 """
70 self.results = results
71 self.original_candidates_ppl = []
72 self.successful_candidates_ppl = []
73

Callers 4

calculateMethod · 0.90
augmentMethod · 0.90
log_summaryMethod · 0.90
test_perplexityFunction · 0.90

Calls

no outgoing calls

Tested by 1

test_perplexityFunction · 0.72