MCPcopy
hub / github.com/PaddlePaddle/PaddleNLP / StaticGraphPredictor

Class StaticGraphPredictor

llm/predictor.py:347–395  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

345
346
class StaticGraphPredictor(BasePredictor):
    """Predictor that runs a model through a ``paddle.inference`` predictor.

    The model is loaded from ``<model_name_or_path>/<model_prefix>.pdmodel``
    and its weights from the sibling ``.pdiparams`` file.
    """

    def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer = None):
        """Assemble the inference configuration and create the predictor.

        Args:
            config: Predictor settings. This method reads
                ``model_name_or_path``, ``model_prefix`` and ``device``.
            tokenizer: Passed through unchanged to ``BasePredictor.__init__``.
        """
        super().__init__(config, tokenizer)

        model_dir = self.config.model_name_or_path
        file_stem = self.config.model_prefix
        params_file = os.path.join(model_dir, file_stem + ".pdiparams")
        model_file = os.path.join(model_dir, file_stem + ".pdmodel")
        infer_cfg = paddle.inference.Config(model_file, params_file)

        target_device = self.config.device
        if target_device == "gpu":
            # GPU-specific settings go here.
            # NOTE(review): (100, 0) is presumably the initial memory pool
            # size in MB and the GPU device id -- confirm against Paddle docs.
            infer_cfg.enable_use_gpu(100, 0)
        elif target_device == "cpu":
            # CPU-specific settings go here (e.g. enable_mkldnn,
            # set_cpu_math_library_num_threads could be added).
            infer_cfg.disable_gpu()

        # Applied regardless of the selected device.
        infer_cfg.disable_glog_info()
        infer_cfg.enable_new_executor()

        if in_pir_executor_mode():
            infer_cfg.enable_new_ir()
            if in_cinn_mode():
                infer_cfg.enable_cinn()

        # The predictor is created while the static-mode guard is active.
        with static_mode_guard():
            self.predictor = paddle.inference.create_predictor(infer_cfg)

        # NOTE(review): presumably consumed by the base class to request
        # NumPy outputs from the tokenizer -- confirm in BasePredictor.
        self.return_tensors = "np"

    def _preprocess(self, input_text: str | list[str]):
        """Run the base preprocessing and attach the generation settings.

        Args:
            input_text: A single prompt or a list of prompts.

        Returns:
            The mapping produced by ``BasePredictor._preprocess`` with
            ``max_new_tokens``, ``top_p``, ``temperature``, ``top_k`` and
            ``repetition_penalty`` added as NumPy arrays.
        """
        feed = super()._preprocess(input_text)

        # (input key, config attribute, dtype); note that the config field
        # ``max_length`` is fed under the key ``max_new_tokens``.
        generation_fields = (
            ("max_new_tokens", "max_length", "int64"),
            ("top_p", "top_p", "float32"),
            ("temperature", "temperature", "float32"),
            ("top_k", "top_k", "int64"),
            ("repetition_penalty", "repetition_penalty", "float32"),
        )
        for input_key, config_attr, dtype in generation_fields:
            feed[input_key] = np.array(getattr(self.config, config_attr), dtype=dtype)

        return feed

    def _infer(self, inputs: dict[str, np.ndarray]):
        """Feed ``inputs`` to the predictor, run it, and return the first output.

        Args:
            inputs: Arrays keyed by input name. Every name reported by
                ``self.predictor.get_input_names()`` must be present;
                a missing one raises ``KeyError``.

        Returns:
            The predictor's first output converted with ``tolist()``.
        """
        engine = self.predictor

        for input_name in engine.get_input_names():
            handle = engine.get_input_handle(input_name)
            handle.copy_from_cpu(inputs[input_name])

        engine.run()

        # The first output holds the decoded ids; other outputs are ignored.
        first_output = engine.get_output_names()[0]
        return engine.get_output_handle(first_output).copy_to_cpu().tolist()
396
397
398class InferencePredictorMixin:

Callers 1

create_predictorFunction · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…