Perform a single training step on a batch of inputs.
training_step(self, model, tokenizer, batch)
        return eval_dataloader

    def training_step(self, model, tokenizer, batch):
        """Perform a single training step on a batch of inputs.

        Args:
            model (:obj:`torch.nn.Module`):
                Model to train.
            tokenizer:
                Tokenizer used to tokenize input text.
            batch (:obj:`tuple[list[str], torch.Tensor, torch.Tensor]`):
                By default, this is a tuple of input texts, targets, and a boolean tensor indicating whether each sample is an adversarial example.

                .. note::
                    If you override the :meth:`get_train_dataloader` method, then the shape and type of :obj:`batch` will depend on how you created your batch.

        Returns:
            :obj:`tuple[torch.Tensor, torch.Tensor, torch.Tensor]` where

            - **loss**: :obj:`torch.FloatTensor` of shape 1 containing the loss.
            - **preds**: :obj:`torch.FloatTensor` of the model's predictions for the batch.
            - **targets**: :obj:`torch.Tensor` of the model's targets (e.g. labels, target values).
        """

        input_texts, targets, is_adv_sample = batch
        # Keep a handle on the original targets before moving them to the device.
        _targets = targets
        targets = targets.to(textattack.shared.utils.device)

        # Hugging Face models (optionally wrapped in DataParallel) take keyword inputs.
        if isinstance(model, transformers.PreTrainedModel) or (
            isinstance(model, torch.nn.DataParallel)
            and isinstance(model.module, transformers.PreTrainedModel)
        ):
            input_ids = tokenizer(
                input_texts,
                padding="max_length",
                return_tensors="pt",
                truncation=True,
            )
            # BatchEncoding.to() moves its tensors in place, so no reassignment is needed.
            input_ids.to(textattack.shared.utils.device)
            logits = model(**input_ids)[0]
        else:
            # Other models receive a single tensor of token IDs.
            input_ids = tokenizer(input_texts)
            if not isinstance(input_ids, torch.Tensor):
                input_ids = torch.tensor(input_ids)
            input_ids = input_ids.to(textattack.shared.utils.device)
            logits = model(input_ids)

        if self.task_type == "regression":
            loss = self.loss_fct(logits.squeeze(), targets.squeeze())
            preds = logits
        else:
            loss = self.loss_fct(logits, targets)
            preds = logits.argmax(dim=-1)

        # Weight adversarial samples by alpha; all other samples keep weight 1.
        sample_weights = torch.ones(
            is_adv_sample.size(), device=textattack.shared.utils.device
        )
        sample_weights[is_adv_sample] *= self.training_args.alpha
        loss = loss * sample_weights
        loss = torch.mean(loss)
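
The final lines of the listing weight each sample's loss before averaging: adversarial samples are multiplied by `self.training_args.alpha` and all others by 1. The sketch below reproduces that arithmetic in plain Python, assuming that `self.loss_fct` yields one loss value per sample (for example, a loss built with `reduction="none"`). The helper name `weighted_mean_loss` and the sample numbers are hypothetical and are not part of TextAttack.

```python
def weighted_mean_loss(per_sample_losses, is_adv_sample, alpha):
    """Scale the losses of adversarial samples by alpha, then average over the batch.

    Hypothetical helper mirroring `torch.mean(loss * sample_weights)` with lists.
    """
    if len(per_sample_losses) != len(is_adv_sample):
        raise ValueError("losses and adversarial flags must have the same length")
    if not per_sample_losses:
        raise ValueError("batch must contain at least one sample")

    # Every sample starts with weight 1.0; adversarial samples get weight alpha.
    weights = [alpha if is_adv else 1.0 for is_adv in is_adv_sample]
    weighted = [loss * weight for loss, weight in zip(per_sample_losses, weights)]

    # The mean divides by the batch size, not by the sum of the weights.
    return sum(weighted) / len(weighted)


# Hypothetical batch: four per-sample losses, two of them adversarial.
losses = [0.5, 1.0, 2.0, 0.5]
flags = [False, True, False, True]

batch_loss = weighted_mean_loss(losses, flags, alpha=2.0)
plain_mean = weighted_mean_loss(losses, flags, alpha=1.0)
print(batch_loss, plain_mean)
```

With `alpha=2.0` the two adversarial samples count double, giving a batch loss of 1.375; with `alpha=1.0` the result is the plain mean, 1.0. Because the sum is divided by the batch size, an `alpha` above 1 raises the overall loss magnitude and also shifts the balance toward adversarial samples.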