MCPcopy Index your code
hub / github.com/QData/TextAttack / get_train_dataloader

Method get_train_dataloader

textattack/trainer.py:388–450  ·  view source on GitHub ↗

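Returns the :obj:`torch.utils.data.DataLoader` for training. Args: dataset (:class:`~textattack.datasets.Dataset`): Original training dataset. adv_dataset (:class:`~textattack.datasets.Dataset`): Adversarial examples generated from the original training dataset, or :obj:`None` if no adversarial attack takes place. batch_size (:obj:`int`): Batch size for training. Returns: :obj:`torch.utils.data.DataLoader`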

(self, dataset, adv_dataset, batch_size)

Source from the content-addressed store, hash-verified

386 return optimizer, scheduler
387
388 def get_train_dataloader(self, dataset, adv_dataset, batch_size):
389 """Returns the :obj:`torch.utils.data.DataLoader` for training.
390
391 Args:
392 dataset (:class:`~textattack.datasets.Dataset`):
393 Original training dataset.
394 adv_dataset (:class:`~textattack.datasets.Dataset`):
395 Adversarial examples generated from the original training dataset. :obj:`None` if no adversarial attack takes place.
396 batch_size (:obj:`int`):
397 Batch size for training.
398 Returns:
399 :obj:`torch.utils.data.DataLoader`
400 """
401
402 # TODO: Add pairing option where we can pair original examples with adversarial examples.
403 # Helper functions for collating data
def collate_fn(data):
    """Collate ``(input, label)`` dataset items into one training batch.

    An item whose input mapping contains the key ``"_example_type"`` is
    treated as an adversarial example; that marker is stripped from the
    input before the remaining values are used as the model input.

    Args:
        data:
            Iterable of ``(input, label)`` pairs. ``input`` must expose
            ``.keys()``; a :obj:`collections.OrderedDict` contributes its
            values, any other input is converted with :obj:`tuple`.
    Returns:
        Tuple ``(input_texts, targets, is_adv_sample)`` where ``input_texts``
        is a list (single-column inputs are unwrapped from their 1-tuple),
        ``targets`` is a :obj:`torch.Tensor` built from the labels and
        ``is_adv_sample`` is a boolean :obj:`torch.Tensor` flagging the
        adversarial examples.
    Raises:
        ValueError: If ``"_example_type"`` is present but its value is not
            ``"adversarial_example"``.
    """
    input_texts = []
    targets = []
    is_adv_sample = []
    for item in data:
        _input, label = item
        if "_example_type" in _input.keys():
            # Strip the marker from a shallow copy: popping it from the
            # dataset's own mapping would erase the adversarial flag for any
            # later pass that receives the same object again.
            _input = _input.copy()
            example_type = _input.pop("_example_type")
            if example_type != "adversarial_example":
                raise ValueError(
                    "`_example_type` must be `adversarial_example` to mark an "
                    f"adversarial example, but got {example_type!r}."
                )
            is_adv_sample.append(True)
        else:
            # No marker present: this is an original (clean) training example.
            is_adv_sample.append(False)

        # With the marker removed, every remaining column is model input.
        if isinstance(_input, collections.OrderedDict):
            _input = tuple(_input.values())
        else:
            _input = tuple(_input)

        # Single-column inputs are passed on as the bare value, not a 1-tuple.
        if len(_input) == 1:
            _input = _input[0]
        input_texts.append(_input)
        targets.append(label)

    return input_texts, torch.tensor(targets), torch.tensor(is_adv_sample)
439
440 if adv_dataset:
441 dataset = torch.utils.data.ConcatDataset([dataset, adv_dataset])
442
443 train_dataloader = torch.utils.data.DataLoader(
444 dataset,
445 batch_size=batch_size,

Callers 1

trainMethod · 0.95

Calls

no outgoing calls

Tested by

no test coverage detected