| 1019 | |
| 1020 | |
| 1021 | class StaticBlockInferencePredictor(BlockInferencePredictorMixin, BasePredictor): |
def __init__(
    self,
    config: PredictorArgument,
    cache_kvs_shape: list[list[int]],
    tokenizer: PretrainedTokenizer = None,
):
    """Set up the static block-inference predictor and its input tensors.

    Args:
        config: Predictor arguments. This constructor reads
            ``export_precache``, ``use_cachekv_int8`` and (only when the KV
            cache is not int8) ``dtype``; the object is also handed to both
            base initialisers, ``init_inputs`` and ``_create_predictor``.
        cache_kvs_shape: Cache shapes laid out as alternating entries
            ``[key_0, value_0, key_1, value_1, ...]``; every pair produces
            one ``key_caches_<i>`` / ``value_caches_<i>`` zero tensor.
        tokenizer: Forwarded unchanged to both base initialisers.
    """
    self.cache_kvs_shape = cache_kvs_shape

    # Each base is initialised explicitly (not through super()), in this order.
    BasePredictor.__init__(self, config, tokenizer)
    BlockInferencePredictorMixin.__init__(self, config, tokenizer)

    self.init_inputs(config)

    # NOTE(review): num_layers / pre_caches are presumably populated by the
    # base initialisers or init_inputs above -- confirm.
    if config.export_precache:
        for layer_idx in range(self.num_layers):
            self.inputs["pre_caches_{}".format(layer_idx)] = self.pre_caches[layer_idx]

    self.cache_kvs = {}
    # An int8 KV cache ("dynamic" or "static") is allocated as uint8 buffers;
    # any other setting allocates the caches in the configured dtype.
    if config.use_cachekv_int8 in ("dynamic", "static"):
        cache_dtype = "uint8"
    else:
        cache_dtype = config.dtype

    # Even positions hold key shapes, odd positions hold value shapes; zip
    # stops at the shorter slice, so a trailing unpaired entry is ignored.
    key_shapes = self.cache_kvs_shape[0::2]
    value_shapes = self.cache_kvs_shape[1::2]
    for pair_idx, (key_shape, value_shape) in enumerate(zip(key_shapes, value_shapes)):
        self.cache_kvs["key_caches_{}".format(pair_idx)] = paddle.zeros(key_shape, dtype=cache_dtype)
        self.cache_kvs["value_caches_{}".format(pair_idx)] = paddle.zeros(value_shape, dtype=cache_dtype)

    # Per-layer quant/dequant scales become predictor inputs only in
    # "dynamic" int8 mode.
    for layer_idx in range(self.num_layers):
        if self.config.use_cachekv_int8 != "dynamic":
            continue
        suffix = str(layer_idx)
        self.inputs["k_quant_scales_" + suffix] = self.k_quant_scales[layer_idx]
        self.inputs["v_quant_scales_" + suffix] = self.v_quant_scales[layer_idx]
        self.inputs["k_dequant_scales_" + suffix] = self.k_dequant_scales[layer_idx]
        self.inputs["v_dequant_scales_" + suffix] = self.v_dequant_scales[layer_idx]

    # The predictor must exist before its input names and handles can be read.
    self._create_predictor(config)
    self.input_names = self.predictor.get_input_names()

    self._share_data()
    self.seq_lens_handle = self.predictor.get_input_handle("seq_lens_this_time")
| 1066 | |
| 1067 | def _create_predictor(self, predictor_args: PredictorArgument): |
| 1068 | if not is_paddlenlp_ops_available(): |
| 1069 | raise ValueError( |
| 1070 | "you should install the paddlenlp ops to run inference predictor, " |
| 1071 | "https://github.com/PaddlePaddle/PaddleNLP/blob/develop/csrc/README.md" |
| 1072 | ) |
| 1073 | |
| 1074 | infer_model_path = get_infer_model_path(predictor_args.model_name_or_path, predictor_args.model_prefix) |
| 1075 | |
| 1076 | config = paddle.inference.Config(infer_model_path + ".pdmodel", infer_model_path + ".pdiparams") |
| 1077 | |
| 1078 | config.switch_ir_optim(False) |
no outgoing calls
no test coverage detected
searching dependent graphs…