MCPcopy
hub / github.com/QData/TextAttack / _attack_parallel

Method _attack_parallel

textattack/attacker.py:230–403  ·  view source on GitHub ↗
(self)

Source retrieved from the content-addressed store (hash-verified).

228 print()
229
230 def _attack_parallel(self):
231 pytorch_multiprocessing_workaround()
232
233 if self._checkpoint:
234 num_remaining_attacks = self._checkpoint.num_remaining_attacks
235 worklist = self._checkpoint.worklist
236 worklist_candidates = self._checkpoint.worklist_candidates
237 logger.info(
238 f"Recovered from checkpoint previously saved at {self._checkpoint.datetime}."
239 )
240 else:
241 if self.attack_args.num_successful_examples:
242 num_remaining_attacks = self.attack_args.num_successful_examples
243 # We make `worklist` deque (linked-list) for easy pop and append.
244 # Candidates are other samples we can attack if we need more samples.
245 worklist, worklist_candidates = self._get_worklist(
246 self.attack_args.num_examples_offset,
247 len(self.dataset),
248 self.attack_args.num_successful_examples,
249 self.attack_args.shuffle,
250 )
251 else:
252 num_remaining_attacks = self.attack_args.num_examples
253 # We make `worklist` deque (linked-list) for easy pop and append.
254 # Candidates are other samples we can attack if we need more samples.
255 worklist, worklist_candidates = self._get_worklist(
256 self.attack_args.num_examples_offset,
257 len(self.dataset),
258 self.attack_args.num_examples,
259 self.attack_args.shuffle,
260 )
261
262 in_queue = torch.multiprocessing.Queue()
263 out_queue = torch.multiprocessing.Queue()
264 for i in worklist:
265 try:
266 example, ground_truth_output = self.dataset[i]
267 example = textattack.shared.AttackedText(example)
268 if self.dataset.label_names is not None:
269 example.attack_attrs["label_names"] = self.dataset.label_names
270 in_queue.put((i, example, ground_truth_output))
271 except IndexError:
272 raise IndexError(
273 f"Tried to access element at {i} in dataset of size {len(self.dataset)}."
274 )
275
276 # We reserve the first GPU for coordinating workers.
277 num_gpus = torch.cuda.device_count()
278 num_workers = self.attack_args.num_workers_per_device * num_gpus
279 logger.info(f"Running {num_workers} worker(s) on {num_gpus} GPU(s).")
280
281 # Lock for synchronization
282 lock = mp.Lock()
283
284 # We move Attacker (and its components) to CPU b/c we don't want models using wrong GPU in worker processes.
285 self.attack.cpu_()
286 torch.cuda.empty_cache()
287

Callers (1)

attack_dataset (method) · 0.95

Calls (9 reported; 8 listed)

_get_worklist (method) · 0.95
save (method) · 0.95
cpu_ (method) · 0.80
log_result (method) · 0.80
enable_stdout (method) · 0.80
log_summary (method) · 0.80
close (method) · 0.45
flush (method) · 0.45

Tested by

No test coverage detected.