| 67 | """ |
| 68 | |
def __init__(self, attack, dataset, attack_args=None):
    """Store the attack, the dataset, and the options for this attack run.

    Args:
        attack: The ``textattack.Attack`` to run.
        dataset: The ``textattack.datasets.Dataset`` whose examples are attacked.
        attack_args: Optional ``textattack.AttackArgs``. When ``None`` (the
            default), a default-constructed ``AttackArgs()`` is used.

    Raises:
        AssertionError: If ``attack``, ``dataset``, or a non-``None``
            ``attack_args`` is not of the expected type.
    """
    # Validate with explicit checks instead of ``assert`` statements, which
    # are stripped under ``python -O`` and would let wrong types through
    # silently. AssertionError is kept as the exception type so existing
    # callers that catch it keep working.
    if not isinstance(attack, Attack):
        raise AssertionError(
            f"`attack` argument must be of type `textattack.Attack`, but got type of `{type(attack)}`."
        )
    if not isinstance(dataset, textattack.datasets.Dataset):
        raise AssertionError(
            f"`dataset` must be of type `textattack.datasets.Dataset`, but got type `{type(dataset)}`."
        )

    # Only ``None`` selects the defaults. Any other value, even a falsy one
    # such as ``0`` or ``""``, must be a real AttackArgs instance rather than
    # being silently replaced.
    if attack_args is None:
        attack_args = AttackArgs()
    elif not isinstance(attack_args, AttackArgs):
        raise AssertionError(
            f"`attack_args` must be of type `textattack.AttackArgs`, but got type `{type(attack_args)}`."
        )

    self.attack = attack
    self.dataset = dataset
    self.attack_args = attack_args
    # Starts as None here; presumably assigned later in the attack workflow
    # (not visible in this chunk -- confirm).
    self.attack_log_manager = None

    # This is to be set if loading from a checkpoint
    self._checkpoint = None
| 91 | |
| 92 | def _get_worklist(self, start, end, num_examples, shuffle): |
| 93 | if end - start < num_examples: |