hub / github.com/QData/TextAttack / training_step

Method training_step

textattack/trainer.py:489–549

Perform a single training step on a batch of inputs.

(self, model, tokenizer, batch)
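By default, `batch` is a tuple of input texts, targets, and a boolean mask that flags adversarial samples. The sketch below illustrates that layout only. The `collate` helper and the example sentences are hypothetical and not part of TextAttack, and plain Python lists stand in for the tensors so the snippet runs without PyTorch.

```python
def collate(samples):
    """Group (text, target, is_adversarial) triples into the default batch layout."""
    input_texts = [text for text, _, _ in samples]
    targets = [target for _, target, _ in samples]
    is_adv_sample = [is_adv for _, _, is_adv in samples]
    return input_texts, targets, is_adv_sample


batch = collate([
    ("a fine movie", 1, False),
    ("a fine film", 1, True),  # hypothetical adversarial rewrite of the first sample
    ("dull and slow", 0, False),
])

# Same three-way unpacking that the method performs on its batch argument.
input_texts, targets, is_adv_sample = batch
```

If `get_train_dataloader` is overridden to yield a different structure, this unpacking has to change with it.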

Source from the content-addressed store, hash-verified

    def training_step(self, model, tokenizer, batch):
        """Perform a single training step on a batch of inputs.

        Args:
            model (:obj:`torch.nn.Module`):
                Model to train.
            tokenizer:
                Tokenizer used to tokenize input text.
            batch (:obj:`tuple[list[str], torch.Tensor, torch.Tensor]`):
                By default, this will be a tuple of input texts, targets, and boolean tensor indicating if the sample is an adversarial example.

                .. note::
                    If you override the :meth:`get_train_dataloader` method, then shape/type of :obj:`batch` will depend on how you created your batch.

        Returns:
            :obj:`tuple[torch.Tensor, torch.Tensor, torch.Tensor]` where

            - **loss**: :obj:`torch.FloatTensor` of shape 1 containing the loss.
            - **preds**: :obj:`torch.FloatTensor` of model's prediction for the batch.
            - **targets**: :obj:`torch.Tensor` of model's targets (e.g. labels, target values).
        """

        input_texts, targets, is_adv_sample = batch
        _targets = targets
        targets = targets.to(textattack.shared.utils.device)

        if isinstance(model, transformers.PreTrainedModel) or (
            isinstance(model, torch.nn.DataParallel)
            and isinstance(model.module, transformers.PreTrainedModel)
        ):
            input_ids = tokenizer(
                input_texts,
                padding="max_length",
                return_tensors="pt",
                truncation=True,
            )
            input_ids.to(textattack.shared.utils.device)
            logits = model(**input_ids)[0]
        else:
            input_ids = tokenizer(input_texts)
            if not isinstance(input_ids, torch.Tensor):
                input_ids = torch.tensor(input_ids)
            input_ids = input_ids.to(textattack.shared.utils.device)
            logits = model(input_ids)

        if self.task_type == "regression":
            loss = self.loss_fct(logits.squeeze(), targets.squeeze())
            preds = logits
        else:
            loss = self.loss_fct(logits, targets)
            preds = logits.argmax(dim=-1)

        sample_weights = torch.ones(
            is_adv_sample.size(), device=textattack.shared.utils.device
        )
        sample_weights[is_adv_sample] *= self.training_args.alpha
        loss = loss * sample_weights
        loss = torch.mean(loss)
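The final lines of the listing scale each per-sample loss by `alpha` when the sample is adversarial, then average over the batch. Below is a minimal sketch of that arithmetic. The `weighted_mean_loss` function and the example values are hypothetical, and plain Python lists stand in for tensors so it runs without PyTorch. It assumes the loss function returns one value per sample rather than a reduced scalar, which the element-wise multiplication in the listing suggests.

```python
def weighted_mean_loss(per_sample_loss, is_adv_sample, alpha):
    """Scale adversarial samples' losses by alpha, then average over the batch."""
    # Every sample starts with weight 1.0 (torch.ones in the listing).
    weights = [1.0] * len(per_sample_loss)
    # Only positions flagged as adversarial are multiplied by alpha
    # (boolean-mask indexing in the listing).
    for i, is_adv in enumerate(is_adv_sample):
        if is_adv:
            weights[i] *= alpha
    weighted = [loss * w for loss, w in zip(per_sample_loss, weights)]
    # Plain mean over the batch size, as torch.mean does.
    return sum(weighted) / len(weighted)


losses = [0.5, 1.0, 2.0, 0.5]          # illustrative per-sample losses
is_adv = [False, True, False, True]    # illustrative adversarial mask
result = weighted_mean_loss(losses, is_adv, alpha=2.0)
baseline = weighted_mean_loss(losses, [False] * 4, alpha=2.0)
```

Because the mean divides by the batch size and not by the sum of the weights, an `alpha` above 1 raises the overall loss for batches that contain adversarial samples instead of only rebalancing it.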

Callers 1

train · Method · 0.95

Calls 2

to · Method · 0.80
size · Method · 0.80

Tested by

no test coverage detected