| 14 | |
| 15 | |
| 16 | class Perplexity(Metric): |
def __init__(self, model_name="gpt2"):
    """Set up the language model and tokenizer used for perplexity scoring.

    Args:
        model_name (str): Name of the pre-trained model to load. The
            default ``"gpt2"`` loads ``GPT2LMHeadModel`` together with
            ``GPT2Tokenizer``; any other value is handed to
            ``AutoModelForMaskedLM`` and ``AutoTokenizer``.
    """
    # Per-run bookkeeping containers start out empty.
    self.all_metrics = {}
    self.original_candidates = []
    self.successful_candidates = []

    # The transformers classes are imported lazily, only for the branch
    # that is actually taken.
    if model_name != "gpt2":
        from transformers import AutoModelForMaskedLM, AutoTokenizer

        lm = AutoModelForMaskedLM.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        lm.to(textattack.shared.utils.device)
        # Generic configs expose their position limit under this name.
        position_limit = lm.config.max_position_embeddings
    else:
        from transformers import GPT2LMHeadModel, GPT2Tokenizer

        lm = GPT2LMHeadModel.from_pretrained("gpt2")
        lm.to(textattack.shared.utils.device)
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        # The GPT-2 config exposes its position limit as ``n_positions``.
        position_limit = lm.config.n_positions

    # Switch the model to evaluation mode before storing it.
    lm.eval()

    self.ppl_model = lm
    self.ppl_tokenizer = tokenizer
    self.max_length = position_limit

    # Fixed stride value, independent of the chosen model.
    self.stride = 512
| 40 | |
| 41 | def calculate(self, results): |
| 42 | """Calculates average Perplexity on all successful attacks using a |
| 43 | pre-trained small GPT-2 model. |
| 44 | |
| 45 | Args: |
| 46 | results (``AttackResult`` objects): |
| 47 | Attack results for each instance in dataset |
| 48 | |
| 49 | Example:: |
| 50 | |
| 51 | |
| 52 | >> import textattack |
| 53 | >> import transformers |
| 54 | >> model = transformers.AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english") |
| 55 | >> tokenizer = transformers.AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english") |
| 56 | >> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer) |
| 57 | >> attack = textattack.attack_recipes.DeepWordBugGao2018.build(model_wrapper) |
| 58 | >> dataset = textattack.datasets.HuggingFaceDataset("glue", "sst2", split="train") |
| 59 | >> attack_args = textattack.AttackArgs( |
| 60 | num_examples=1, |
| 61 | log_to_csv="log.csv", |
| 62 | checkpoint_interval=5, |
| 63 | checkpoint_dir="checkpoints", |
| 64 | disable_stdout=True |
| 65 | ) |
| 66 | >> attacker = textattack.Attacker(attack, dataset, attack_args) |
| 67 | >> results = attacker.attack_dataset() |
| 68 | >> ppl = textattack.metrics.quality_metrics.Perplexity().calculate(results) |
| 69 | """ |
| 70 | self.results = results |
| 71 | self.original_candidates_ppl = [] |
| 72 | self.successful_candidates_ppl = [] |
| 73 |
no outgoing calls