hub / github.com/QData/TextAttack / evaluate_step

Method evaluate_step

textattack/trainer.py:551–596

Perform a single evaluation step on a batch of inputs.

(self, model, tokenizer, batch)

Source from the content-addressed store, hash-verified

549         return loss, preds, _targets
550
551     def evaluate_step(self, model, tokenizer, batch):
552         """Perform a single evaluation step on a batch of inputs.
553
554         Args:
555             model (:obj:`torch.nn.Module`):
556                 Model to train.
557             tokenizer:
558                 Tokenizer used to tokenize input text.
559             batch (:obj:`tuple[list[str], torch.Tensor]`):
560                 By default, this will be a tuple of input texts and target tensors.
561
562                 .. note::
563                     If you override the :meth:`get_eval_dataloader` method, then shape/type of :obj:`batch` will depend on how you created your batch.
564
565         Returns:
566             :obj:`tuple[torch.Tensor, torch.Tensor]` where
567
568             - **preds**: :obj:`torch.FloatTensor` of model's prediction for the batch.
569             - **targets**: :obj:`torch.Tensor` of model's targets (e.g. labels, target values).
570         """
571         input_texts, targets = batch
572         _targets = targets
573         targets = targets.to(textattack.shared.utils.device)
574
575         if isinstance(model, transformers.PreTrainedModel):
576             input_ids = tokenizer(
577                 input_texts,
578                 padding="max_length",
579                 return_tensors="pt",
580                 truncation=True,
581             )
582             input_ids.to(textattack.shared.utils.device)
583             logits = model(**input_ids)[0]
584         else:
585             input_ids = tokenizer(input_texts)
586             if not isinstance(input_ids, torch.Tensor):
587                 input_ids = torch.tensor(input_ids)
588             input_ids = input_ids.to(textattack.shared.utils.device)
589             logits = model(input_ids)
590
591         if self.task_type == "regression":
592             preds = logits
593         else:
594             preds = logits.argmax(dim=-1)
595
596         return preds.cpu(), _targets
597
598     def train(self):
599         """Train the model on given training dataset."""
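
The final branch of the listing (lines 591 to 594) decides how raw model outputs become predictions: regression keeps the outputs unchanged, while any other task type takes the index of the largest logit along the last dimension. A minimal sketch of that rule follows, assuming plain Python lists stand in for tensors; the helper name `to_predictions` is hypothetical and is not part of TextAttack.

```python
from typing import List, Union

Number = Union[int, float]


def to_predictions(logits: List[List[Number]], task_type: str):
    """Hypothetical stand-in for the last branch of evaluate_step.

    Plain nested lists replace torch tensors so the sketch has no
    third-party dependencies.
    """
    if task_type == "regression":
        # Regression: the model outputs are already the predictions.
        return logits
    # Otherwise: pick the index of the largest logit in each row,
    # which is the list analogue of an argmax over the last dimension.
    return [max(range(len(row)), key=row.__getitem__) for row in logits]


# Two examples with two class logits each.
class_preds = to_predictions([[0.1, 0.9], [2.0, -1.0]], "classification")
# One example with a single regression output.
reg_preds = to_predictions([[0.5]], "regression")
```

In the real method the same decision is made on a `torch.Tensor`, and the result is moved to the CPU before being returned together with the original targets.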

Callers 1

evaluate · Method · 0.95
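
Because `evaluate_step` returns one `(preds, targets)` pair per batch, a caller has to combine those pairs before computing a metric. The sketch below shows one way that could look for classification accuracy. It is an illustration only, not the code of TextAttack's `evaluate` method: the function name `accuracy_over_batches` is hypothetical, and plain lists stand in for tensors.

```python
from typing import Iterable, List, Tuple


def accuracy_over_batches(
    step_outputs: Iterable[Tuple[List[int], List[int]]]
) -> float:
    """Hypothetical aggregation of per-batch (preds, targets) pairs."""
    all_preds: List[int] = []
    all_targets: List[int] = []
    for preds, targets in step_outputs:
        # Concatenate the batch results in the order they were produced.
        all_preds.extend(preds)
        all_targets.extend(targets)
    if not all_targets:
        raise ValueError("no examples to evaluate")
    # Count positions where the predicted label equals the target label.
    correct = sum(int(p == t) for p, t in zip(all_preds, all_targets))
    return correct / len(all_targets)


# Two batches: the first has one correct prediction out of two,
# the second has one correct prediction out of one.
batches = [([1, 0], [1, 1]), ([0], [0])]
accuracy = accuracy_over_batches(batches)
```

For a regression task an equality count would not be meaningful, so a different metric would have to replace the accuracy computation.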

Calls 1

to · Method · 0.80

Tested by

no test coverage detected