(
self,
config: PredictorArgument,
model: PretrainedModel = None,
tokenizer: PretrainedTokenizer = None,
)
| 957 | |
| 958 | class DygraphBlockInferencePredictor(BlockInferencePredictorMixin, BasePredictor): |
def __init__(
    self,
    config: PredictorArgument,
    model: PretrainedModel = None,
    tokenizer: PretrainedTokenizer = None,
):
    """Set up a dygraph block-inference predictor around ``model``.

    The KV-cache shapes are obtained from ``model.get_cache_kvs_shape``,
    both base classes are initialised explicitly, zero-filled cache
    tensors are allocated, and the caches (plus optional pre-caches and
    dynamic quantisation scales) are registered in ``self.inputs``.

    Args:
        config: Predictor settings. This method reads ``batch_size``,
            ``use_cachekv_int8`` and ``export_precache`` from it.
        model: The model to run. Although the signature defaults to
            ``None``, a model is required because the cache shapes are
            derived from it.
        tokenizer: Tokenizer forwarded to both base-class initialisers.

    Raises:
        ValueError: If ``model`` is ``None``.
    """
    # Fail fast with a clear message instead of an AttributeError on
    # NoneType: the very first step below dereferences ``model``.
    if model is None:
        raise ValueError("DygraphBlockInferencePredictor requires a model, but got model=None.")

    # The cache shapes are computed before the base initialisers run,
    # matching the original statement order.
    self.cache_kvs_shape = model.get_cache_kvs_shape(model.config, config.batch_size)
    BasePredictor.__init__(self, config, tokenizer)
    BlockInferencePredictorMixin.__init__(self, config, tokenizer)

    # Both int8 cache modes allocate uint8 caches; otherwise use self.dtype
    # (presumably set by one of the base initialisers above -- not visible
    # in this block, TODO confirm).
    if config.use_cachekv_int8 in ("dynamic", "static"):
        cache_dtype = "uint8"
    else:
        cache_dtype = self.dtype
    self.cache_kvs = [paddle.zeros(shape, dtype=cache_dtype) for shape in self.cache_kvs_shape]

    self.model = model

    self.init_inputs(config)
    # NOTE(review): pre_caches and the *_quant_scales / *_dequant_scales
    # attributes are not assigned in this block; presumably the mixin
    # initialiser or init_inputs creates them -- verify.
    if config.export_precache:
        self.inputs["pre_caches"] = self.pre_caches
    if config.use_cachekv_int8 == "dynamic":
        # Only the dynamic int8 mode passes the scale tensors as inputs.
        self.inputs["k_quant_scales"] = self.k_quant_scales
        self.inputs["v_quant_scales"] = self.v_quant_scales
        self.inputs["k_dequant_scales"] = self.k_dequant_scales
        self.inputs["v_dequant_scales"] = self.v_dequant_scales

    self.inputs["cache_kvs"] = self.cache_kvs
| 986 | |
| 987 | @paddle.no_grad() |
| 988 | def _infer(self, inputs: dict[str, paddle.Tensor]): |
nothing calls this directly
no test coverage detected