class Trainer:
    """Trainer is a training and evaluation loop for adversarial training.

    It is designed to work with PyTorch and Transformers models.

    Args:
        model_wrapper (:class:`~textattack.models.wrappers.ModelWrapper`):
            Model wrapper containing both the model and the tokenizer.
        task_type (:obj:`str`, `optional`, defaults to :obj:`"classification"`):
            The task that the model is trained to perform.
            Currently, :class:`~textattack.Trainer` supports two tasks: (1) :obj:`"classification"` and (2) :obj:`"regression"`.
        attack (:class:`~textattack.Attack`):
            :class:`~textattack.Attack` used to generate adversarial examples for training.
        train_dataset (:class:`~textattack.datasets.Dataset`):
            Dataset for training.
        eval_dataset (:class:`~textattack.datasets.Dataset`):
            Dataset for evaluation.
        training_args (:class:`~textattack.TrainingArgs`):
            Arguments for training.

    Example::

        >>> import textattack
        >>> import transformers

        >>> model = transformers.AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
        >>> tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
        >>> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

        >>> # We use DeepWordBugGao2018 only for demonstration purposes.
        >>> attack = textattack.attack_recipes.DeepWordBugGao2018.build(model_wrapper)
        >>> train_dataset = textattack.datasets.HuggingFaceDataset("imdb", split="train")
        >>> eval_dataset = textattack.datasets.HuggingFaceDataset("imdb", split="test")

        >>> # Train for 3 epochs with 1 initial clean epoch, 1000 adversarial examples per epoch,
        >>> # a learning rate of 5e-5, and an effective batch size of 32 (batch size 8 x 4 accumulation steps).
        >>> training_args = textattack.TrainingArgs(
        ...     num_epochs=3,
        ...     num_clean_epochs=1,
        ...     num_train_adv_examples=1000,
        ...     learning_rate=5e-5,
        ...     per_device_train_batch_size=8,
        ...     gradient_accumulation_steps=4,
        ...     log_to_tb=True,
        ... )

        >>> trainer = textattack.Trainer(
        ...     model_wrapper,
        ...     "classification",
        ...     attack,
        ...     train_dataset,
        ...     eval_dataset,
        ...     training_args,
        ... )
        >>> trainer.train()

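    When :obj:`parallel=True` is set in :class:`~textattack.TrainingArgs`, the example above can be
    laid out as a script whose training code runs only under the entry-point guard. A minimal sketch,
    assuming the script is launched directly with ``python`` (``main`` is a hypothetical function name;
    every other call mirrors the example above)::

        import textattack
        import transformers


        def main():
            # Same setup as the example above.
            model = transformers.AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
            tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
            model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
            attack = textattack.attack_recipes.DeepWordBugGao2018.build(model_wrapper)
            train_dataset = textattack.datasets.HuggingFaceDataset("imdb", split="train")
            eval_dataset = textattack.datasets.HuggingFaceDataset("imdb", split="test")

            # parallel=True may start worker processes, which is why the guard below is needed.
            training_args = textattack.TrainingArgs(
                num_epochs=3,
                num_clean_epochs=1,
                num_train_adv_examples=1000,
                parallel=True,
            )
            trainer = textattack.Trainer(
                model_wrapper, "classification", attack, train_dataset, eval_dataset, training_args
            )
            trainer.train()


        # If workers are started with the "spawn" method, each one imports this module;
        # the guard keeps them from calling main() and starting training again.
        if __name__ == "__main__":
            main()
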
    .. note::
        When using :class:`~textattack.Trainer` with `parallel=True` in :class:`~textattack.TrainingArgs`,
        make sure to protect the “entry point” of the program by using :obj:`if __name__ == '__main__':`.