MCPcopy
hub / github.com/PaddlePaddle/PaddleNLP / DygraphBlockInferencePredictor

Class DygraphBlockInferencePredictor

llm/predictor.py:958–1018  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

956
957
958class DygraphBlockInferencePredictor(BlockInferencePredictorMixin, BasePredictor):
def __init__(
    self,
    config: PredictorArgument,
    model: PretrainedModel | None = None,
    tokenizer: PretrainedTokenizer = None,
):
    """Build the predictor state around an already-loaded ``model``.

    Args:
        config: Predictor settings. This method reads ``batch_size``,
            ``use_cachekv_int8`` and ``export_precache`` and forwards the
            whole object to both base initializers and ``init_inputs``.
        model: Model that provides ``get_cache_kvs_shape`` and ``config``;
            it is stored as ``self.model``. Required: the ``None`` default
            is kept only so the signature stays backward-compatible.
        tokenizer: Forwarded unchanged to both base initializers.

    Raises:
        ValueError: If ``model`` is ``None``.
    """
    # Fail early with a clear message instead of the AttributeError that
    # ``None.get_cache_kvs_shape`` would raise on the next statement.
    if model is None:
        raise ValueError("DygraphBlockInferencePredictor requires a model, but got model=None.")

    # NOTE(review): the shape is assigned before the base initializers run,
    # presumably because one of them reads it -- keep this order.
    self.cache_kvs_shape = model.get_cache_kvs_shape(model.config, config.batch_size)
    BasePredictor.__init__(self, config, tokenizer)
    BlockInferencePredictorMixin.__init__(self, config, tokenizer)

    # Both cache-KV int8 modes allocate uint8 buffers; any other setting
    # uses ``self.dtype`` (not set in this method -- presumably by a base
    # initializer above).
    if config.use_cachekv_int8 in ("dynamic", "static"):
        self.cache_kvs = [paddle.zeros(shape, dtype="uint8") for shape in self.cache_kvs_shape]
    else:
        self.cache_kvs = [paddle.zeros(shape, dtype=self.dtype) for shape in self.cache_kvs_shape]

    self.model = model

    self.init_inputs(config)
    # ``self.inputs``, ``self.pre_caches`` and the scale tensors below are
    # not created in this method -- presumably by ``init_inputs`` or a base
    # initializer; verify against the mixin.
    if config.export_precache:
        self.inputs["pre_caches"] = self.pre_caches
    if config.use_cachekv_int8 == "dynamic":
        # Only the "dynamic" mode adds the quant/dequant scales to the inputs.
        self.inputs["k_quant_scales"] = self.k_quant_scales
        self.inputs["v_quant_scales"] = self.v_quant_scales
        self.inputs["k_dequant_scales"] = self.k_dequant_scales
        self.inputs["v_dequant_scales"] = self.v_dequant_scales

    self.inputs["cache_kvs"] = self.cache_kvs
986
@paddle.no_grad()
def _infer(self, inputs: dict[str, paddle.Tensor]):
    """Call ``self.model.generate`` with ``inputs`` expanded as keyword arguments.

    The value returned by ``generate`` is discarded; this method returns
    ``None``.
    """
    generate = self.model.generate
    generate(**inputs)
992
993 @paddle.no_grad()
994 def predict(self, input_texts: str | list[str]):
995 self._preprocess(input_texts)
996
997 result_queue = mp.Queue()
998 tensor_queue = mp.Queue()
999
1000 output_tensor = paddle.full(shape=[MAX_BSZ + 2, 1], fill_value=2, dtype="int64")
1001 output_tensor = output_tensor.cpu()
1002 tensor_queue.put(output_tensor)
1003
1004 read_res_process = mp.Process(target=read_res, args=[self.model_name_or_path, tensor_queue, result_queue])
1005 read_res_process.start()
1006
1007 while self.inputs["not_need_stop"]:
1008 self._infer(self.inputs)
1009 # reset free_list
1010 for i in range(self.config.batch_size):
1011 self.free_list.extend(self.used_list[i])
1012 self.used_list[i] = []
1013 reset_stop_value(self.inputs["not_need_stop"])
1014
1015 outputs = []

Callers 1

create_predictor · Function · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…