Perform a single evaluation step on a batch of inputs. Args: model (:obj:`torch.nn.Module`): Model to evaluate. tokenizer: Tokenizer used to tokenize input text. batch (:obj:`tuple[list[str], torch.Tensor]`): By default, this will be a tuple of input texts and target tensors.
(self, model, tokenizer, batch)
| 549 | return loss, preds, _targets |
| 550 | |
def evaluate_step(self, model, tokenizer, batch):
    """Perform a single evaluation step on a batch of inputs.

    Args:
        model (:obj:`torch.nn.Module`):
            Model to evaluate.
        tokenizer:
            Tokenizer used to tokenize input text.
        batch (:obj:`tuple[list[str], torch.Tensor]`):
            By default, this will be a tuple of input texts and target tensors.

            .. note::
                If you override the :meth:`get_eval_dataloader` method, then shape/type of :obj:`batch` will depend on how you created your batch.

    Returns:
        :obj:`tuple[torch.Tensor, torch.Tensor]` where

        - **preds**: :obj:`torch.Tensor` of the model's predictions for the batch,
          moved to CPU. When ``self.task_type == "regression"`` these are the raw
          model outputs; otherwise they are the (int64) ``argmax`` indices over the
          last dimension.
        - **targets**: :obj:`torch.Tensor` of the model's targets (e.g. labels,
          target values), returned exactly as received in ``batch`` (not moved to
          the model's device).
    """
    input_texts, targets = batch
    # Targets are only handed back to the caller and never fed to the model, so
    # there is no reason to copy them to the model's device here.

    if isinstance(model, transformers.PreTrainedModel):
        input_ids = tokenizer(
            input_texts,
            padding="max_length",
            return_tensors="pt",
            truncation=True,
        )
        # Rebind the result instead of relying on ``.to`` mutating in place.
        input_ids = input_ids.to(textattack.shared.utils.device)
        logits = model(**input_ids)[0]
    else:
        input_ids = tokenizer(input_texts)
        if not isinstance(input_ids, torch.Tensor):
            input_ids = torch.tensor(input_ids)
        input_ids = input_ids.to(textattack.shared.utils.device)
        logits = model(input_ids)

    if self.task_type == "regression":
        preds = logits
    else:
        preds = logits.argmax(dim=-1)

    return preds.cpu(), targets
| 597 | |
| 598 | def train(self): |
| 599 | """Train the model on given training dataset.""" |