MCPcopy
hub / github.com/PaddlePaddle/PaddleNLP / StaticBlockInferencePredictor

Class StaticBlockInferencePredictor

llm/predictor.py:1021–1166  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

1019
1020
1021class StaticBlockInferencePredictor(BlockInferencePredictorMixin, BasePredictor):
def __init__(
    self,
    config: PredictorArgument,
    cache_kvs_shape: list[list[int]],
    tokenizer: PretrainedTokenizer = None,
):
    """Prepare model inputs and KV-cache tensors, then build the predictor.

    Args:
        config: Predictor arguments. This method reads ``export_precache``,
            ``use_cachekv_int8`` and ``dtype`` from it.
        cache_kvs_shape: Cache shapes stored as alternating entries: even
            indices are key-cache shapes, odd indices are value-cache shapes.
        tokenizer: Tokenizer forwarded to both base-class initializers.
    """
    self.cache_kvs_shape = cache_kvs_shape
    BasePredictor.__init__(self, config, tokenizer)
    BlockInferencePredictorMixin.__init__(self, config, tokenizer)

    self.init_inputs(config)

    if config.export_precache:
        for layer_idx in range(self.num_layers):
            self.inputs[f"pre_caches_{layer_idx}"] = self.pre_caches[layer_idx]

    # Caches are allocated as uint8 in either int8 cache-KV mode; otherwise
    # they use the dtype from the config.
    self.cache_kvs = {}
    cache_dtype = "uint8" if config.use_cachekv_int8 in ("dynamic", "static") else config.dtype
    shape_pairs = zip(self.cache_kvs_shape[0::2], self.cache_kvs_shape[1::2])
    for pair_idx, (key_shape, value_shape) in enumerate(shape_pairs):
        self.cache_kvs[f"key_caches_{pair_idx}"] = paddle.zeros(key_shape, dtype=cache_dtype)
        self.cache_kvs[f"value_caches_{pair_idx}"] = paddle.zeros(value_shape, dtype=cache_dtype)

    # Per-layer quant/dequant scales are registered as inputs only in
    # "dynamic" mode.
    scale_attrs = ("k_quant_scales", "v_quant_scales", "k_dequant_scales", "v_dequant_scales")
    for layer_idx in range(self.num_layers):
        if self.config.use_cachekv_int8 != "dynamic":
            continue
        for attr_name in scale_attrs:
            self.inputs[f"{attr_name}_{layer_idx}"] = getattr(self, attr_name)[layer_idx]

    self._create_predictor(config)
    self.input_names = self.predictor.get_input_names()

    self._share_data()
    self.seq_lens_handle = self.predictor.get_input_handle("seq_lens_this_time")
1066
1067 def _create_predictor(self, predictor_args: PredictorArgument):
1068 if not is_paddlenlp_ops_available():
1069 raise ValueError(
1070 "you should install the paddlenlp ops to run inference predictor, "
1071 "https://github.com/PaddlePaddle/PaddleNLP/blob/develop/csrc/README.md"
1072 )
1073
1074 infer_model_path = get_infer_model_path(predictor_args.model_name_or_path, predictor_args.model_prefix)
1075
1076 config = paddle.inference.Config(infer_model_path + ".pdmodel", infer_model_path + ".pdiparams")
1077
1078 config.switch_ir_optim(False)

Callers 1

create_predictor · Function · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…