MCPcopy
hub / github.com/QData/TextAttack / _generate_adversarial_examples

Method _generate_adversarial_examples

textattack/trainer.py:164–251  ·  view source on GitHub ↗

Generate adversarial examples using attacker.

(self, epoch)

Source from the content-addressed store, hash-verified

162 self._global_step = 0
163
164 def _generate_adversarial_examples(self, epoch):
165 """Generate adversarial examples using attacker."""
166 assert (
167 self.attack is not None
168 ), "`attack` is `None` but attempting to generate adversarial examples."
169 base_file_name = f"attack-train-{epoch}"
170 log_file_name = os.path.join(self.training_args.output_dir, base_file_name)
171 logger.info("Attacking model to generate new adversarial training set...")
172
173 if isinstance(self.training_args.num_train_adv_examples, float):
174 num_train_adv_examples = math.ceil(
175 len(self.train_dataset) * self.training_args.num_train_adv_examples
176 )
177 else:
178 num_train_adv_examples = self.training_args.num_train_adv_examples
179
180 # Use Different AttackArgs based on num_train_adv_examples value.
181 # If num_train_adv_examples >= 0 , num_train_adv_examples is
182 # set as number of successful examples.
183 # If num_train_adv_examples == -1 , num_examples is set to -1 to
184 # generate example for all of training data.
185 if num_train_adv_examples >= 0:
186 attack_args = AttackArgs(
187 num_successful_examples=num_train_adv_examples,
188 num_examples_offset=0,
189 query_budget=self.training_args.query_budget_train,
190 shuffle=True,
191 parallel=self.training_args.parallel,
192 num_workers_per_device=self.training_args.attack_num_workers_per_device,
193 disable_stdout=True,
194 silent=True,
195 log_to_txt=log_file_name + ".txt",
196 log_to_csv=log_file_name + ".csv",
197 )
198 elif num_train_adv_examples == -1:
199 # set num_examples when num_train_adv_examples = -1
200 attack_args = AttackArgs(
201 num_examples=num_train_adv_examples,
202 num_examples_offset=0,
203 query_budget=self.training_args.query_budget_train,
204 shuffle=True,
205 parallel=self.training_args.parallel,
206 num_workers_per_device=self.training_args.attack_num_workers_per_device,
207 disable_stdout=True,
208 silent=True,
209 log_to_txt=log_file_name + ".txt",
210 log_to_csv=log_file_name + ".csv",
211 )
212 else:
213 assert False, "num_train_adv_examples is negative and not equal to -1."
214
215 attacker = Attacker(self.attack, self.train_dataset, attack_args=attack_args)
216 results = attacker.attack_dataset()
217
218 attack_types = collections.Counter(r.__class__.__name__ for r in results)
219 total_attacks = (
220 attack_types["SuccessfulAttackResult"] + attack_types["FailedAttackResult"]
221 )

Callers 1

train (Method) · 0.95

Calls 3

attack_dataset (Method) · 0.95
AttackArgs (Class) · 0.85
Attacker (Class) · 0.85

Tested by

no test coverage detected