github.com/QData/TextAttack

Class Trainer

textattack/trainer.py:30–1019



class Trainer:
    """Trainer is a training and evaluation loop for adversarial training.

    It is designed to work with PyTorch and Transformers models.

    Args:
        model_wrapper (:class:`~textattack.models.wrappers.ModelWrapper`):
            Model wrapper containing both the model and the tokenizer.
        task_type (:obj:`str`, `optional`, defaults to :obj:`"classification"`):
            The task that the model is trained to perform.
            Currently, :class:`~textattack.Trainer` supports two tasks: (1) :obj:`"classification"`, (2) :obj:`"regression"`.
        attack (:class:`~textattack.Attack`):
            :class:`~textattack.Attack` used to generate adversarial examples for training.
        train_dataset (:class:`~textattack.datasets.Dataset`):
            Dataset for training.
        eval_dataset (:class:`~textattack.datasets.Dataset`):
            Dataset for evaluation.
        training_args (:class:`~textattack.TrainingArgs`):
            Arguments for training.

    Example::

        >>> import textattack
        >>> import transformers

        >>> model = transformers.AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
        >>> tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
        >>> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

        >>> # We use DeepWordBugGao2018 only for demonstration purposes.
        >>> attack = textattack.attack_recipes.DeepWordBugGao2018.build(model_wrapper)
        >>> train_dataset = textattack.datasets.HuggingFaceDataset("imdb", split="train")
        >>> eval_dataset = textattack.datasets.HuggingFaceDataset("imdb", split="test")

        >>> # Train for 3 epochs with 1 initial clean epoch, 1000 adversarial examples per epoch, a learning rate of 5e-5, and an effective batch size of 32 (8x4).
        >>> training_args = textattack.TrainingArgs(
        ...     num_epochs=3,
        ...     num_clean_epochs=1,
        ...     num_train_adv_examples=1000,
        ...     learning_rate=5e-5,
        ...     per_device_train_batch_size=8,
        ...     gradient_accumulation_steps=4,
        ...     log_to_tb=True,
        ... )

        >>> trainer = textattack.Trainer(
        ...     model_wrapper,
        ...     "classification",
        ...     attack,
        ...     train_dataset,
        ...     eval_dataset,
        ...     training_args
        ... )
        >>> trainer.train()

    .. note::
        When using :class:`~textattack.Trainer` with `parallel=True` in :class:`~textattack.TrainingArgs`,
        make sure to protect the “entry point” of the program by using :obj:`if __name__ == '__main__':`.
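The `task_type` argument chooses between "classification" and "regression". As a rough, framework-free illustration of what that choice usually changes, the sketch below assumes cross-entropy loss for classification and mean squared error for regression (the common PyTorch pairing of `CrossEntropyLoss` and `MSELoss`). `compute_loss` is a hypothetical helper, not part of TextAttack's API, and this is not claimed to be the Trainer's exact internals.

```python
import math
from typing import Sequence


def compute_loss(task_type: str, outputs: Sequence, labels: Sequence) -> float:
    """Hypothetical helper: mean loss over a batch, selected by task type.

    "classification": `outputs` is a list of logit vectors and `labels` holds
    integer class indices; returns mean cross-entropy.
    "regression": `outputs` and `labels` are lists of floats; returns mean
    squared error.
    """
    if len(outputs) != len(labels):
        raise ValueError("outputs and labels must have the same length")
    if task_type == "classification":
        total = 0.0
        for logits, label in zip(outputs, labels):
            # Numerically stable log-sum-exp: subtract the max logit first.
            m = max(logits)
            log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
            # Cross-entropy = -log softmax(logits)[label].
            total += log_sum_exp - logits[label]
        return total / len(labels)
    if task_type == "regression":
        return sum((p - y) ** 2 for p, y in zip(outputs, labels)) / len(labels)
    raise ValueError(f"unsupported task_type: {task_type!r}")


print(compute_loss("classification", [[0.0, 0.0]], [0]))  # ln 2
print(compute_loss("regression", [1.0, 2.0], [1.0, 4.0]))  # 2.0
```

With two equal logits the model assigns probability 0.5 to the true class, so the cross-entropy is ln 2; any other string for `task_type` is rejected, mirroring the two supported tasks.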

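The comment in the example packs two calculations into one line: the effective batch size (8 per device x 4 accumulation steps = 32) and a schedule of 1 clean epoch followed by adversarial epochs. The sketch below spells both out. `ScheduleArgs`, `effective_batch_size` and `epoch_plan` are hypothetical names; it assumes that the first `num_clean_epochs` epochs train on clean data only and that each later epoch adds `num_train_adv_examples` adversarial examples to the clean data, and it treats `num_devices` as an assumed multiplier for data-parallel training.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ScheduleArgs:
    """Hypothetical stand-in for the TrainingArgs fields used in the example."""
    num_epochs: int
    num_clean_epochs: int
    num_train_adv_examples: int
    per_device_train_batch_size: int
    gradient_accumulation_steps: int


def effective_batch_size(args: ScheduleArgs, num_devices: int = 1) -> int:
    # Examples contributing to one optimizer step:
    # per-device batch x accumulation steps x devices (assumed data parallel).
    return args.per_device_train_batch_size * args.gradient_accumulation_steps * num_devices


def epoch_plan(args: ScheduleArgs) -> List[str]:
    # Assumption: clean-only epochs come first, then adversarial epochs.
    plan = []
    for epoch in range(1, args.num_epochs + 1):
        if epoch <= args.num_clean_epochs:
            plan.append(f"epoch {epoch}: clean only")
        else:
            plan.append(f"epoch {epoch}: clean + {args.num_train_adv_examples} adversarial")
    return plan


args = ScheduleArgs(num_epochs=3, num_clean_epochs=1, num_train_adv_examples=1000,
                    per_device_train_batch_size=8, gradient_accumulation_steps=4)
print(effective_batch_size(args))  # 32
for line in epoch_plan(args):
    print(line)
```

For the docstring's settings this prints 32, then one clean epoch followed by two epochs that each add 1000 adversarial examples.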
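The note about protecting the entry point follows the general Python multiprocessing rule. The generic sketch below assumes worker processes are started with the "spawn" method (the default on Windows and macOS), under which each child re-imports the main module. `worker` and `main` are hypothetical placeholders; no TextAttack calls are made here.

```python
import multiprocessing as mp
from typing import List


def worker(shard: int) -> int:
    # Placeholder for per-process work (for example, attacking one shard of data).
    return shard * shard


def main() -> List[int]:
    # In a real script, the model wrapper, attack, datasets and Trainer would be
    # built here, followed by trainer.train().
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=2) as pool:
        return pool.map(worker, range(4))


if __name__ == "__main__":
    # Under "spawn", children import this module under a different __name__,
    # so the guard stops them from re-running main(); without it, each child
    # would try to start new processes while bootstrapping and fail with a
    # RuntimeError.
    print(main())
```

Keeping all setup inside `main()` means the only module-level work is defining functions, which is safe to repeat in every child process.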