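Returns the :obj:`torch.utils.data.DataLoader` for training. Args: dataset (:class:`~textattack.datasets.Dataset`): Original training dataset. adv_dataset (:class:`~textattack.datasets.Dataset`): Adversarial examples generated from the original training dataset, or :obj:`None` if no adversarial attack takes place. batch_size (:obj:`int`): Batch size for training. Returns: :obj:`torch.utils.data.DataLoader`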
(self, dataset, adv_dataset, batch_size)
| 386 | return optimizer, scheduler |
| 387 | |
| 388 | def get_train_dataloader(self, dataset, adv_dataset, batch_size): |
| 389 | """Returns the :obj:`torch.utils.data.DataLoader` for training. |
| 390 | |
| 391 | Args: |
| 392 | dataset (:class:`~textattack.datasets.Dataset`): |
| 393 | Original training dataset. |
| 394 | adv_dataset (:class:`~textattack.datasets.Dataset`): |
| 395 | Adversarial examples generated from the original training dataset. :obj:`None` if no adversarial attack takes place. |
| 396 | batch_size (:obj:`int`): |
| 397 | Batch size for training. |
| 398 | Returns: |
| 399 | :obj:`torch.utils.data.DataLoader` |
| 400 | """ |
| 401 | |
| 402 | # TODO: Add pairing option where we can pair original examples with adversarial examples. |
| 403 | # Helper functions for collating data |
def collate_fn(data):
    """Collate a batch of ``(input, label)`` items for the training DataLoader.

    Each item is a 2-tuple ``(input, label)``. ``input`` is either a mapping
    (e.g. :class:`collections.OrderedDict`) of input columns or a sequence of
    input values. A mapping may carry an extra ``"_example_type"`` marker
    key; when present its value must be ``"adversarial_example"`` and the
    item is flagged as adversarial. The marker is never forwarded as model
    input, and the caller's mapping is left unmodified.

    Args:
        data: Iterable of ``(input, label)`` pairs.

    Returns:
        tuple: ``(input_texts, targets, is_adv_sample)`` where
        ``input_texts`` is a list (single-column inputs are unwrapped to the
        bare value, multi-column inputs become tuples of values),
        ``targets`` is ``torch.tensor`` of the labels, and
        ``is_adv_sample`` is a boolean ``torch.tensor`` marking adversarial
        items.

    Raises:
        ValueError: If an item's ``"_example_type"`` marker is present but
            is not ``"adversarial_example"``.
    """
    # Local import: only the abstract base class is needed here, and a
    # function-scope import keeps this helper self-contained.
    from collections.abc import Mapping

    input_texts = []
    targets = []
    is_adv_sample = []
    for _input, label in data:
        is_adv = False
        if isinstance(_input, Mapping):
            if "_example_type" in _input:
                example_type = _input["_example_type"]
                if example_type != "adversarial_example":
                    raise ValueError(
                        "`item` has an `_example_type` field, but its value "
                        f"{example_type!r} is not `adversarial_example`."
                    )
                is_adv = True
            # Drop the marker while reading values instead of `pop`-ing it, so
            # the mapping owned by the dataset is not mutated; a mutated item
            # would silently lose its adversarial flag on a later fetch.
            _input = tuple(
                value for key, value in _input.items() if key != "_example_type"
            )
        else:
            _input = tuple(_input)

        # Single-column inputs are passed to the model as the bare value.
        if len(_input) == 1:
            _input = _input[0]
        input_texts.append(_input)
        targets.append(label)
        is_adv_sample.append(is_adv)

    return input_texts, torch.tensor(targets), torch.tensor(is_adv_sample)
| 439 | |
| 440 | if adv_dataset: |
| 441 | dataset = torch.utils.data.ConcatDataset([dataset, adv_dataset]) |
| 442 | |
| 443 | train_dataloader = torch.utils.data.DataLoader( |
| 444 | dataset, |
| 445 | batch_size=batch_size, |