MCPcopy
hub / github.com/PaddlePaddle/PaddleNLP / __init__

Method __init__

llm/predictor.py:267–304  ·  view source on GitHub ↗
(
        self, config: PredictorArgument, model: PretrainedModel = None, tokenizer: PretrainedTokenizer = None
    )

Source from the content-addressed store, hash-verified

265
266class DygraphPredictor(BasePredictor):
def __init__(
    self, config: PredictorArgument, model: PretrainedModel = None, tokenizer: PretrainedTokenizer = None
):
    """Set up the predictor's model, loading it and any adapters if required.

    Args:
        config: Predictor settings. This method reads ``lora_path``,
            ``prefix_path``, ``dtype``, ``model_name_or_path`` and
            ``use_flash_attention`` from it.
        model: An already constructed model to reuse. When ``None``, the
            model is loaded from ``config.model_name_or_path``.
        tokenizer: Forwarded unchanged to the base-class initializer.

    Raises:
        ValueError: If none of ``config.lora_path``, ``config.prefix_path``
            or ``config.dtype`` is set, so no dtype can be chosen. The
            check runs even when ``model`` is supplied.
    """
    super().__init__(config, tokenizer)
    self.model = model

    # dtype precedence: LoRA config first, then prefix config, then the
    # explicit config.dtype. Only the first matching source is consulted.
    if config.lora_path is not None:
        lora_config = LoRAConfig.from_pretrained(config.lora_path)
        dtype = lora_config.dtype
        # Set before LoRAModel.from_pretrained below receives this config.
        lora_config.merge_weights = True
    elif config.prefix_path is not None:
        prefix_config = PrefixConfig.from_pretrained(config.prefix_path)
        dtype = prefix_config.dtype
    elif config.dtype is not None:
        dtype = config.dtype
    else:
        raise ValueError("Please specify the model dtype.")

    # Load the base model only when the caller did not hand one in.
    if self.model is None:
        self.model = AutoModelForCausalLM.from_pretrained(
            config.model_name_or_path,
            use_flash_attention=config.use_flash_attention,
            dtype=dtype,
            # NOTE(review): tensor_parallel_* are presumably set by the
            # base-class initializer; confirm against BasePredictor.
            tensor_parallel_degree=self.tensor_parallel_degree,
            tensor_parallel_rank=self.tensor_parallel_rank,
        )

    # The two adapter checks are independent `if`s (not `elif`): when both
    # paths are configured, LoRA is applied first and prefix tuning wraps it.
    if config.lora_path is not None:
        self.model = LoRAModel.from_pretrained(
            model=self.model, lora_path=config.lora_path, lora_config=lora_config
        )
    if config.prefix_path is not None:
        prefix_tuning_params = get_prefix_tuning_params(self.model)
        self.model = PrefixModelForCausalLM.from_pretrained(
            model=self.model,
            prefix_path=config.prefix_path,
            postprocess_past_key_value=prefix_tuning_params["postprocess_past_key_value"],
        )
    self.model.eval()
305
306 @paddle.no_grad()
307 def _infer(self, inputs: dict[str, paddle.Tensor]):

Callers

nothing calls this directly

Calls 4

get_prefix_tuning_params — Function · 0.90
__init__ — Method · 0.45
from_pretrained — Method · 0.45
eval — Method · 0.45

Tested by

no test coverage detected