MCPcopy
hub / github.com/QData/TextAttack / get_eval_dataloader

Method get_eval_dataloader

textattack/trainer.py:452–487  ·  view source on GitHub ↗

Returns the :obj:`torch.utils.data.DataLoader` for evaluation. Args: dataset (:class:`~textattack.datasets.Dataset`): Dataset to use for evaluation. batch_size (:obj:`int`): Batch size for evaluation. Returns: :obj:`torch.utils.data.DataLoader`.

(self, dataset, batch_size)

Source from the content-addressed store, hash-verified

450 return train_dataloader
451
def get_eval_dataloader(self, dataset, batch_size):
    """Returns the :obj:`torch.utils.data.DataLoader` for evaluation.

    Examples are yielded in dataset order (no shuffling), so evaluation is
    reproducible and batches follow the dataset's own ordering. Each batch
    is an ``(input_texts, targets)`` pair: ``input_texts`` is a list and
    ``targets`` is a tensor built with :func:`torch.tensor` from the labels.

    Args:
        dataset (:class:`~textattack.datasets.Dataset`):
            Dataset to use for evaluation. Each item is expected to be an
            ``(input, label)`` pair.
        batch_size (:obj:`int`):
            Batch size for evaluation.
    Returns:
        :obj:`torch.utils.data.DataLoader`
    """

    # Helper function for collating data
    def collate_fn(data):
        """Group ``(input, label)`` pairs into ``(input_texts, targets)``.

        ``OrderedDict`` inputs are reduced to a tuple of their values, plain
        strings are kept whole, and any other iterable is converted to a
        tuple. A single-element tuple is unwrapped to its bare element.
        """
        input_texts = []
        targets = []
        for _input, label in data:
            if isinstance(_input, collections.OrderedDict):
                _input = tuple(_input.values())
            elif isinstance(_input, str):
                # tuple("text") would split the string into characters.
                _input = (_input,)
            else:
                _input = tuple(_input)

            if len(_input) == 1:
                _input = _input[0]
            input_texts.append(_input)
            targets.append(label)
        return input_texts, torch.tensor(targets)

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        # Evaluation does not benefit from a random order; a fixed order
        # makes evaluation runs reproducible.
        shuffle=False,
        collate_fn=collate_fn,
        # Pinned host memory only helps host-to-GPU copies; skip requesting
        # it when CUDA is unavailable.
        pin_memory=torch.cuda.is_available(),
    )
488
489 def training_step(self, model, tokenizer, batch):
490 """Perform a single training step on a batch of inputs.

Callers 1

evaluateMethod · 0.95

Calls

no outgoing calls

Tested by

no test coverage detected