Returns the :obj:`torch.utils.data.DataLoader` for evaluation. Args: dataset (:class:`~textattack.datasets.Dataset`): Dataset to use for evaluation. batch_size (:obj:`int`): Batch size for evaluation. Returns: :obj:`torch.utils.data.DataLoader`
(self, dataset, batch_size)
| 450 | return train_dataloader |
| 451 | |
def get_eval_dataloader(self, dataset, batch_size):
    """Returns the :obj:`torch.utils.data.DataLoader` for evaluation.

    Examples are yielded in dataset order (no shuffling), so evaluation is
    deterministic and batch contents line up with dataset indices.

    Args:
        dataset (:class:`~textattack.datasets.Dataset`):
            Dataset to use for evaluation. Each item must be an
            ``(input, label)`` pair, where ``input`` is a mapping of column
            name to text (e.g. an ``OrderedDict``), a sequence of texts, or
            a single string.
        batch_size (:obj:`int`):
            Batch size for evaluation.
    Returns:
        :obj:`torch.utils.data.DataLoader` whose batches are
        ``(input_texts, targets)``: ``input_texts`` is a list with one entry
        per example (a single text for one-column inputs, otherwise a tuple
        of texts), and ``targets`` is a tensor built from the labels.
    """

    # Helper function for collating data
    def collate_fn(data):
        input_texts = []
        targets = []
        for _input, label in data:
            if isinstance(_input, str):
                # A bare string is one text column; ``tuple(str)`` would
                # split it into individual characters.
                _input = (_input,)
            elif isinstance(_input, dict):
                # ``dict`` also covers ``OrderedDict``; take the column values
                # in insertion order (iterating a dict directly gives keys).
                _input = tuple(_input.values())
            else:
                _input = tuple(_input)

            # Single-column inputs are unwrapped to the bare text.
            if len(_input) == 1:
                _input = _input[0]
            input_texts.append(_input)
            targets.append(label)
        return input_texts, torch.tensor(targets)

    eval_dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        # Evaluation must not shuffle: keeps results reproducible and
        # aligned with dataset order.
        shuffle=False,
        collate_fn=collate_fn,
        # Pinned memory only helps host->CUDA transfers; requesting it
        # without a CUDA device just emits a warning.
        pin_memory=torch.cuda.is_available(),
    )
    return eval_dataloader
| 488 | |
| 489 | def training_step(self, model, tokenizer, batch): |
| 490 | """Perform a single training step on a batch of inputs. |