| 697 | |
| 698 | |
class DygraphInferencePredictor(InferencePredictorMixin, BasePredictor):
    """Predictor that runs inference by calling ``model.generate`` directly on a dygraph model.

    Combines ``BasePredictor`` and ``InferencePredictorMixin``. The cache-KV shape is
    computed from the model before either base initializer runs.
    """

    def __init__(
        self,
        config: PredictorArgument,
        model: PretrainedModel = None,
        tokenizer: PretrainedTokenizer = None,
    ):
        """Build the predictor.

        Args:
            config: Predictor arguments; ``batch_size`` and ``total_max_length`` are
                passed to the model to size the cache KVs.
            model: The model to run. Despite the ``None`` default it is required,
                because ``model.get_cache_kvs_shape`` and ``model.config`` are read here.
            tokenizer: Tokenizer forwarded to both base-class initializers.

        Raises:
            ValueError: If ``model`` is None.
        """
        if model is None:
            # Fail with a clear message instead of an AttributeError on None below.
            raise ValueError("DygraphInferencePredictor requires a `model`, but got None.")
        # Assigned before the mixin initializer runs; presumably
        # InferencePredictorMixin.__init__ reads self.cache_kvs_shape to allocate
        # self.cache_kvs (NOTE(review): confirm against the mixin).
        self.cache_kvs_shape = model.get_cache_kvs_shape(model.config, config.batch_size, config.total_max_length)
        BasePredictor.__init__(self, config, tokenizer)
        InferencePredictorMixin.__init__(self, config, tokenizer)
        self.model = model

    @paddle.no_grad()
    def _infer(self, inputs: dict[str, paddle.Tensor]):
        """Run one generation step with gradients disabled.

        ``inputs`` is mutated in place:
          * non-tensor values are converted with ``paddle.to_tensor``;
          * list values are converted element-wise, and items that are already tensors
            are kept as-is;
          * ``inputs["cache_kvs"]`` is set to ``self.cache_kvs``.

        The value returned by ``self.model.generate`` is discarded and this method
        returns None. Outputs are presumably collected elsewhere (e.g. in a
        post-processing step). TODO confirm.
        """
        # Only values of existing keys are reassigned, so mutating during iteration is safe.
        for key, value in inputs.items():
            if paddle.is_tensor(value):
                continue
            if isinstance(value, list):
                # Convert per item; items that are already tensors are passed through
                # instead of being copied by paddle.to_tensor.
                inputs[key] = [item if paddle.is_tensor(item) else paddle.to_tensor(item) for item in value]
            else:
                inputs[key] = paddle.to_tensor(value)

        inputs["cache_kvs"] = self.cache_kvs
        self.model.generate(**inputs)
        return None
| 728 | |
| 729 | |
| 730 | class BlockInferencePredictorMixin: |
no outgoing calls
no test coverage detected
searching dependent graphs…