MCPcopy
hub / github.com/PaddlePaddle/PaddleNLP / __init__

Method __init__

llm/predictor.py:959–985  ·  view source on GitHub ↗
(
        self,
        config: PredictorArgument,
        model: PretrainedModel = None,
        tokenizer: PretrainedTokenizer = None,
    )

Source from the content-addressed store, hash-verified

957
958class DygraphBlockInferencePredictor(BlockInferencePredictorMixin, BasePredictor):
def __init__(
    self,
    config: PredictorArgument,
    model: PretrainedModel = None,
    tokenizer: PretrainedTokenizer = None,
):
    """Initialize the predictor, allocate KV cache buffers and register inputs.

    Args:
        config: Predictor settings. This method reads ``batch_size``,
            ``use_cachekv_int8`` and ``export_precache`` from it, and passes
            the whole object on to the parent initializers and
            ``init_inputs``.
        model: Model to run. Required in practice even though the signature
            defaults it to ``None``, because the cache shapes are queried
            from it.
        tokenizer: Tokenizer forwarded to both parent initializers.

    Raises:
        ValueError: If ``model`` is ``None``.
    """
    if model is None:
        # The signature allows None, but the very next statement dereferences
        # the model; fail with a clear message instead of an AttributeError
        # on NoneType.
        raise ValueError(
            f"{type(self).__name__} requires a `model` instance to compute the "
            "cache_kvs shapes, but got None."
        )

    # The cache shapes are computed before the parent initializers run, as in
    # the original ordering (presumably the mixin reads self.cache_kvs_shape
    # -- TODO confirm against BlockInferencePredictorMixin).
    self.cache_kvs_shape = model.get_cache_kvs_shape(model.config, config.batch_size)
    BasePredictor.__init__(self, config, tokenizer)
    BlockInferencePredictorMixin.__init__(self, config, tokenizer)

    # With cache-KV int8 enabled (either mode) the buffers are allocated as
    # uint8; otherwise they use self.dtype, which is not assigned in this
    # method (presumably set by a parent initializer -- verify).
    cachekv_int8_enabled = config.use_cachekv_int8 in ("dynamic", "static")
    cache_dtype = "uint8" if cachekv_int8_enabled else self.dtype
    self.cache_kvs = [paddle.zeros(shape, dtype=cache_dtype) for shape in self.cache_kvs_shape]

    self.model = model

    self.init_inputs(config)
    # NOTE(review): self.inputs, self.pre_caches and the four *_scales
    # attributes below are not created in this method; presumably they come
    # from init_inputs or a parent initializer -- confirm.
    if config.export_precache:
        self.inputs["pre_caches"] = self.pre_caches
    if config.use_cachekv_int8 == "dynamic":
        # Only the "dynamic" mode registers the quant/dequant scales as inputs.
        self.inputs["k_quant_scales"] = self.k_quant_scales
        self.inputs["v_quant_scales"] = self.v_quant_scales
        self.inputs["k_dequant_scales"] = self.k_dequant_scales
        self.inputs["v_dequant_scales"] = self.v_dequant_scales

    self.inputs["cache_kvs"] = self.cache_kvs
986
987 @paddle.no_grad()
988 def _infer(self, inputs: dict[str, paddle.Tensor]):

Callers

nothing calls this directly

Calls 3

init_inputsMethod · 0.80
get_cache_kvs_shapeMethod · 0.45
__init__Method · 0.45

Tested by

no test coverage detected