hub / github.com/zai-org/ChatGLM2-6B / predict

Method predict

ptuning/trainer_seq2seq.py:80–136

Runs prediction on `test_dataset` and returns the predictions along with any metrics. If the test dataset contains labels, the method also returns metrics, as `evaluate()` does.

predict(
    self,
    test_dataset: Dataset,
    ignore_keys: Optional[List[str]] = None,
    metric_key_prefix: str = "test",
    **gen_kwargs
) -> PredictionOutput

Source

        return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)

    def predict(
        self,
        test_dataset: Dataset,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "test",
        **gen_kwargs
    ) -> PredictionOutput:
        """
        Run prediction and return predictions and potential metrics.

        Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
        will also return metrics, like in `evaluate()`.

        Args:
            test_dataset (`Dataset`):
                Dataset to run the predictions on. If it is a [`~datasets.Dataset`], columns not accepted by the
                `model.forward()` method are automatically removed. Has to implement the method `__len__`.
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.
            metric_key_prefix (`str`, *optional*, defaults to `"test"`):
                An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
                "test_bleu" if the prefix is `"test"` (default).
            max_length (`int`, *optional*):
                The maximum target length to use when predicting with the generate method.
            num_beams (`int`, *optional*):
                Number of beams for beam search that will be used when predicting with the generate method. 1 means no
                beam search.
            gen_kwargs:
                Additional `generate` specific kwargs.

        <Tip>

        If your predictions or labels have different sequence lengths (for instance because you're doing dynamic
        padding in a token classification task) the predictions will be padded (on the right) to allow for
        concatenation into one array. The padding index is -100.

        </Tip>

        Returns: *NamedTuple* A namedtuple with the following keys:

            - predictions (`np.ndarray`): The predictions on `test_dataset`.
            - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some).
            - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
              labels).
        """

        # Work on a copy so the caller's kwargs are never mutated.
        gen_kwargs = gen_kwargs.copy()
        # Fall back to the training arguments only when neither length limit was given.
        if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
            gen_kwargs["max_length"] = self.args.generation_max_length
        gen_kwargs["num_beams"] = (
            gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
        )
        # Stored on the instance so the generation step can read these settings later.
        self._gen_kwargs = gen_kwargs

        return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
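The default-filling step at the end of `predict` can be isolated as a small pure function, which makes the precedence rules easier to see. The sketch below is illustrative only: `resolve_gen_kwargs` and its two `default_*` parameters are hypothetical names standing in for `self.args.generation_max_length` and `self.args.generation_num_beams`.

```python
from typing import Any, Dict, Optional


def resolve_gen_kwargs(
    gen_kwargs: Dict[str, Any],
    default_max_length: Optional[int],
    default_num_beams: Optional[int],
) -> Dict[str, Any]:
    """Hypothetical helper mirroring the defaulting logic shown in `predict`."""
    # Copy first so the caller's dictionary is left untouched.
    resolved = dict(gen_kwargs)

    # The default length applies only when the caller set neither limit.
    if resolved.get("max_length") is None and resolved.get("max_new_tokens") is None:
        resolved["max_length"] = default_max_length

    # An explicit num_beams wins; a missing or None value falls back to the default.
    if resolved.get("num_beams") is None:
        resolved["num_beams"] = default_num_beams

    return resolved


if __name__ == "__main__":
    print(resolve_gen_kwargs({}, 64, 4))
    print(resolve_gen_kwargs({"max_new_tokens": 32}, 64, 4))
```

Note that passing `max_new_tokens` suppresses the `max_length` default entirely, so the two limits are never both injected by this step.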

Callers 1

main · Function · 0.95
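The docstring of `predict` notes that predictions of unequal length are right-padded with -100. A caller that goes on to decode the returned ids would generally need to swap that sentinel for a real token id first. The following is a minimal sketch using plain lists; `replace_padding` and the pad id of 0 are hypothetical and are not taken from the repository's `main`.

```python
from typing import List

IGNORE_INDEX = -100  # padding index stated in the predict() docstring


def replace_padding(
    rows: List[List[int]],
    pad_token_id: int,
    ignore_index: int = IGNORE_INDEX,
) -> List[List[int]]:
    """Hypothetical helper: replace every ignore_index entry with pad_token_id."""
    return [
        [pad_token_id if token == ignore_index else token for token in row]
        for row in rows
    ]


if __name__ == "__main__":
    # Two predictions of different lengths, right-padded with -100.
    padded = [[5, 6, -100], [7, -100, -100]]
    print(replace_padding(padded, pad_token_id=0))
```

When the predictions are a NumPy array, `np.where(preds != -100, preds, pad_token_id)` performs the same replacement in one call.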

Calls

no outgoing calls

Tested by

no test coverage detected