MCPcopy
hub / github.com/PaddlePaddle/PaddleNLP / DygraphInferencePredictor

Class DygraphInferencePredictor

llm/predictor.py:699–727  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

697
698
class DygraphInferencePredictor(InferencePredictorMixin, BasePredictor):
    """Predictor that runs generation directly on an in-memory (dygraph) model.

    Inputs are converted to ``paddle.Tensor`` as needed and passed, together
    with ``self.cache_kvs``, to ``self.model.generate``.
    """

    def __init__(
        self,
        config: PredictorArgument,
        model: PretrainedModel = None,
        tokenizer: PretrainedTokenizer = None,
    ):
        """Wrap an already-instantiated ``model`` for inference.

        Args:
            config: Predictor arguments. ``config.batch_size`` and
                ``config.total_max_length`` are passed to
                ``model.get_cache_kvs_shape`` to size the KV cache.
            model: The model to run. Required despite the ``None`` default,
                which is kept only for signature compatibility. It must provide
                ``config``, ``get_cache_kvs_shape`` and ``generate``.
            tokenizer: Tokenizer forwarded to both base-class initializers.

        Raises:
            ValueError: If ``model`` is None.
        """
        if model is None:
            raise ValueError("DygraphInferencePredictor requires a `model`, but got None.")
        # Set before the base initializers run. NOTE(review): self.cache_kvs
        # (read in _infer) is not assigned in this class; presumably one of the
        # base initializers builds it from cache_kvs_shape, which is why this
        # ordering matters. Confirm against InferencePredictorMixin.
        self.cache_kvs_shape = model.get_cache_kvs_shape(model.config, config.batch_size, config.total_max_length)
        BasePredictor.__init__(self, config, tokenizer)
        InferencePredictorMixin.__init__(self, config, tokenizer)
        self.model = model

    @paddle.no_grad()
    def _infer(self, inputs: dict[str, paddle.Tensor]):
        """Run ``self.model.generate`` on ``inputs`` without gradient tracking.

        ``inputs`` is mutated in place:
        - values that are not tensors are converted with ``paddle.to_tensor``;
        - list values are converted element-wise, and elements that are already
          tensors are kept as-is;
        - a ``"cache_kvs"`` entry is added, set to ``self.cache_kvs``.

        Args:
            inputs: Mapping of keyword arguments for ``self.model.generate``.

        Returns:
            None. The output of ``generate`` is not captured by this method.
        """
        for key, value in inputs.items():
            if paddle.is_tensor(value):
                continue
            if isinstance(value, list):
                # A list itself is never a tensor, so check each element
                # instead and only wrap the ones that still need conversion.
                inputs[key] = [item if paddle.is_tensor(item) else paddle.to_tensor(item) for item in value]
            else:
                inputs[key] = paddle.to_tensor(value)
        # Overwriting values of existing keys while iterating items() is safe:
        # the dict's size does not change until after the loop.

        inputs["cache_kvs"] = self.cache_kvs
        self.model.generate(**inputs)
        return None
728
729
730class BlockInferencePredictorMixin:

Callers 1

create_predictorFunction · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…