MCPcopy
hub / github.com/PaddlePaddle/PaddleNLP / prediction_step

Method prediction_step

llm/utils.py:199–254  ·  view source on GitHub ↗
(
        self,
        model,
        inputs,
        prediction_loss_only: bool,
        ignore_keys=None,
    )

Source from the content-addressed store, hash-verified

197 self.data_args = data_args
198
def prediction_step(
    self,
    model,
    inputs,
    prediction_loss_only: bool,
    ignore_keys=None,
):
    """Run one evaluation step and return ``(loss, predictions, labels)``.

    Three paths are taken, in order:

    1. When only the loss is requested, or pipeline parallelism is enabled
       (``self.args.pipeline_parallel_degree > 1``), the call is delegated
       unchanged to the parent class.
    2. When ``self.do_generation`` is false, the parent step is run and its
       logits are reduced to argmax indices along the last axis
       (``keepdim=True``) before being returned.
    3. Otherwise tokens are produced with ``model.generate`` using the
       "sampling" decode strategy; the loss slot of the result is ``None``.

    Args:
        model: Model to evaluate. Path 3 calls ``model.eval()`` and
            ``model.generate(...)`` on it.
        inputs: Mapping holding at least ``"input_ids"``; ``"attention_mask"``,
            ``"position_ids"`` and ``"labels"`` are used when present.
        prediction_loss_only: Forwarded to the parent step; when true the
            parent's result is returned as-is.
        ignore_keys: Forwarded to the parent step.

    Returns:
        A 3-tuple ``(loss, predictions, labels)``. In path 3 ``loss`` is
        ``None`` and ``labels`` is ``None`` when ``inputs`` has no
        ``"labels"`` entry.
    """
    # Path 1: nothing to post-process, let the parent handle everything.
    if prediction_loss_only or self.args.pipeline_parallel_degree > 1:
        return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)

    # Path 2: teacher-forced evaluation without generation.
    if not self.do_generation:
        loss, logits, labels = super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)
        if isinstance(logits, (list, tuple)):
            logits = logits[0]

        # With tensor_parallel_output enabled the logits must be all-gathered
        # across the model-parallel group and concatenated on the last axis.
        logits_need_gather = self.args.tensor_parallel_degree > 1 and getattr(
            self.args, "tensor_parallel_output", False
        )
        if logits_need_gather:
            mp_group = fleet.get_hybrid_communicate_group().get_model_parallel_group()
            logit_shards = []
            dist.all_gather(logit_shards, logits, group=mp_group)
            logits = paddle.concat(logit_shards, axis=-1)

        # Reduce to argmax here instead of returning full logits: gathering
        # all logits is too memory-consuming. keepdim keeps the same number
        # of dimensions as the logits.
        predicted_ids = logits.argmax(axis=-1, keepdim=True)
        return (loss, predicted_ids, labels)

    # Path 3: free-running generation; no loss is computed on this path.
    model.eval()
    with paddle.no_grad():
        tokenizer = self.tokenizer
        prompt_ids = inputs["input_ids"]
        # Generate at least one token, even if the prompt already reaches
        # data_args.max_length.
        generation_budget = max(self.data_args.max_length - prompt_ids.shape[-1], 1)

        # NOTE(review): index 0 of generate()'s result is taken to be the
        # generated token ids — confirm against the generation API.
        generated_tokens = model.generate(
            input_ids=prompt_ids,
            attention_mask=inputs["attention_mask"] if "attention_mask" in inputs else None,
            position_ids=inputs["position_ids"] if "position_ids" in inputs else None,
            max_length=generation_budget,
            decode_strategy="sampling",
            top_k=self.gen_args.top_k,
            top_p=self.gen_args.top_p,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            use_cache=True,
        )[0]

        # Drop pad tokens from each generated sequence.
        stripped = [
            sequence[sequence != tokenizer.pad_token_id].tolist() for sequence in generated_tokens
        ]

        # Right-pad every sequence with -100 up to the longest one so the
        # batch can be turned into a rectangular tensor.
        # NOTE(review): -100 is presumably the ignore index expected by the
        # downstream metric code — confirm.
        longest = max([len(tokens) for tokens in stripped])
        padded = [tokens + [-100] * (longest - len(tokens)) for tokens in stripped]
        all_preds = paddle.to_tensor(padded)

        all_labels = paddle.to_tensor(inputs["labels"]) if "labels" in inputs else None

    return (None, all_preds, all_labels)
255
256 def log(self, logs: Dict[str, float], **kwargs) -> None:

Callers 1

ptq_loopMethod · 0.95

Calls 6

tolistMethod · 0.80
to_tensorMethod · 0.80
evalMethod · 0.45
generateMethod · 0.45
appendMethod · 0.45

Tested by

no test coverage detected