| 956 | |
| 957 | |
| 958 | class DygraphBlockInferencePredictor(BlockInferencePredictorMixin, BasePredictor): |
def __init__(
    self,
    config: PredictorArgument,
    model: PretrainedModel | None = None,
    tokenizer: PretrainedTokenizer | None = None,
):
    """Build a dygraph block-inference predictor around an in-memory model.

    Allocates the KV-cache tensors for ``model``, then runs the
    ``BasePredictor`` and ``BlockInferencePredictorMixin`` initialisers.
    Next it builds the input dict via ``init_inputs``. Finally it registers
    the caches in ``self.inputs``, together with the optional pre-caches and
    the dynamic int8 quant/dequant scales.

    Args:
        config: Predictor settings. Fields read here: ``batch_size``,
            ``use_cachekv_int8`` and ``export_precache``.
        model: Dygraph model exposing ``config``, ``get_cache_kvs_shape`` and
            ``generate``. It defaults to ``None`` only for signature
            compatibility; it is required.
        tokenizer: Forwarded unchanged to both base initialisers.

    Raises:
        ValueError: If ``model`` is ``None``.
    """
    if model is None:
        # Fail with a clear message instead of an opaque AttributeError on
        # ``None.get_cache_kvs_shape`` just below.
        raise ValueError("DygraphBlockInferencePredictor requires a `model`, but got None.")

    # NOTE(review): assigned before the base initialisers run; presumably one
    # of them reads ``self.cache_kvs_shape``, so keep this ordering.
    self.cache_kvs_shape = model.get_cache_kvs_shape(model.config, config.batch_size)
    BasePredictor.__init__(self, config, tokenizer)
    BlockInferencePredictorMixin.__init__(self, config, tokenizer)

    # Both int8 cache-KV modes keep the caches as raw uint8 buffers. Otherwise
    # use the predictor dtype (assumed to be set by BasePredictor.__init__ — confirm).
    quantized_cache = config.use_cachekv_int8 in ("dynamic", "static")
    cache_dtype = "uint8" if quantized_cache else self.dtype
    self.cache_kvs = [paddle.zeros(shape, dtype=cache_dtype) for shape in self.cache_kvs_shape]

    self.model = model

    # Presumably creates ``self.inputs``, which is indexed right after — confirm.
    self.init_inputs(config)
    if config.export_precache:
        # NOTE(review): ``self.pre_caches`` is not set in this method; assumed
        # to come from a base initialiser or ``init_inputs`` — confirm.
        self.inputs["pre_caches"] = self.pre_caches
    if config.use_cachekv_int8 == "dynamic":
        # Only the dynamic int8 mode feeds quant/dequant scales to the model.
        # These attributes are not set here; assumed to come from a base
        # initialiser — confirm.
        self.inputs["k_quant_scales"] = self.k_quant_scales
        self.inputs["v_quant_scales"] = self.v_quant_scales
        self.inputs["k_dequant_scales"] = self.k_dequant_scales
        self.inputs["v_dequant_scales"] = self.v_dequant_scales

    self.inputs["cache_kvs"] = self.cache_kvs
| 986 | |
@paddle.no_grad()
def _infer(self, inputs: dict[str, paddle.Tensor]):
    """Run one ``generate`` call of the wrapped model with gradients disabled.

    Every entry of ``inputs`` is passed to ``self.model.generate`` as a
    keyword argument. Whatever ``generate`` returns is dropped, so this
    method returns ``None``.

    NOTE(review): outputs are presumably delivered out of band — ``predict``
    starts a ``read_res`` reader process before looping — confirm.
    """
    generate_fn = self.model.generate
    generate_fn(**inputs)
| 992 | |
| 993 | @paddle.no_grad() |
| 994 | def predict(self, input_texts: str | list[str]): |
| 995 | self._preprocess(input_texts) |
| 996 | |
| 997 | result_queue = mp.Queue() |
| 998 | tensor_queue = mp.Queue() |
| 999 | |
| 1000 | output_tensor = paddle.full(shape=[MAX_BSZ + 2, 1], fill_value=2, dtype="int64") |
| 1001 | output_tensor = output_tensor.cpu() |
| 1002 | tensor_queue.put(output_tensor) |
| 1003 | |
| 1004 | read_res_process = mp.Process(target=read_res, args=[self.model_name_or_path, tensor_queue, result_queue]) |
| 1005 | read_res_process.start() |
| 1006 | |
| 1007 | while self.inputs["not_need_stop"]: |
| 1008 | self._infer(self.inputs) |
| 1009 | # reset free_list |
| 1010 | for i in range(self.config.batch_size): |
| 1011 | self.free_list.extend(self.used_list[i]) |
| 1012 | self.used_list[i] = [] |
| 1013 | reset_stop_value(self.inputs["not_need_stop"]) |
| 1014 | |
| 1015 | outputs = [] |
no outgoing calls
no test coverage detected
searching dependent graphs…