MCPcopy
hub / github.com/PaddlePaddle/PaddleNLP / DygraphPredictor

Class DygraphPredictor

llm/predictor.py:266–344  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

264
265
266class DygraphPredictor(BasePredictor):
def __init__(
    self, config: PredictorArgument, model: PretrainedModel = None, tokenizer: PretrainedTokenizer = None
):
    """Set up a dynamic-graph predictor, loading the model and any adapters.

    Args:
        config: Predictor settings. This constructor reads ``lora_path``,
            ``prefix_path``, ``dtype``, ``model_name_or_path`` and
            ``use_flash_attention`` from it.
        model: An already-built model to use as-is. When ``None``, the model is
            loaded from ``config.model_name_or_path``.
        tokenizer: Forwarded, together with ``config``, to the base-class
            constructor.

    Raises:
        ValueError: If ``config.lora_path``, ``config.prefix_path`` and
            ``config.dtype`` are all ``None``. The check runs even when
            ``model`` is supplied.
    """
    super().__init__(config, tokenizer)
    self.model = model

    # dtype precedence: LoRA config first, then prefix config, then the
    # explicit ``config.dtype``.
    if config.lora_path is not None:
        lora_config = LoRAConfig.from_pretrained(config.lora_path)
        dtype = lora_config.dtype
        # NOTE(review): presumably asks LoRA to fold its weights into the base
        # model for inference -- confirm against LoRAConfig.
        lora_config.merge_weights = True
    elif config.prefix_path is not None:
        prefix_config = PrefixConfig.from_pretrained(config.prefix_path)
        dtype = prefix_config.dtype
    elif config.dtype is not None:
        dtype = config.dtype
    else:
        raise ValueError("Please specify the model dtype.")

    if self.model is None:
        # NOTE(review): tensor_parallel_degree / tensor_parallel_rank are read
        # from ``self``; they are expected to be set by the base-class
        # constructor, which is not visible here.
        self.model = AutoModelForCausalLM.from_pretrained(
            config.model_name_or_path,
            use_flash_attention=config.use_flash_attention,
            dtype=dtype,
            tensor_parallel_degree=self.tensor_parallel_degree,
            tensor_parallel_rank=self.tensor_parallel_rank,
        )

    # The two adapter checks are independent: when both paths are set, the
    # prefix model wraps the LoRA model (and dtype came from the LoRA config).
    if config.lora_path is not None:
        self.model = LoRAModel.from_pretrained(
            model=self.model, lora_path=config.lora_path, lora_config=lora_config
        )
    if config.prefix_path is not None:
        prefix_tuning_params = get_prefix_tuning_params(self.model)
        self.model = PrefixModelForCausalLM.from_pretrained(
            model=self.model,
            prefix_path=config.prefix_path,
            postprocess_past_key_value=prefix_tuning_params["postprocess_past_key_value"],
        )
    self.model.eval()
305
@paddle.no_grad()
def _infer(self, inputs: dict[str, paddle.Tensor]):
    """Run ``self.model.generate`` on ``inputs`` and return element 0 of its result.

    Args:
        inputs: Keyword arguments unpacked into ``generate`` ahead of the
            decoding options assembled below.

    Returns:
        The first element of whatever ``self.model.generate`` returns.
    """
    tokenizer = self.tokenizer
    settings = self.config

    # Decoding options come from the predictor config; the special token ids
    # come from the tokenizer. ``max_new_tokens`` takes ``config.max_length``.
    generation_kwargs = {
        "max_new_tokens": settings.max_length,
        "bos_token_id": tokenizer.bos_token_id,
        "eos_token_id": get_eos_token_id(tokenizer, self.generation_config),
        "pad_token_id": tokenizer.pad_token_id,
        "decode_strategy": settings.decode_strategy,
        "temperature": settings.temperature,
        "top_k": settings.top_k,
        "top_p": settings.top_p,
        "repetition_penalty": settings.repetition_penalty,
    }

    # Two separate unpackings: a key in ``inputs`` that clashes with one of the
    # options above still raises TypeError, as it would with explicit keywords.
    return self.model.generate(**inputs, **generation_kwargs)[0]
322
323 def stream_predict(self, inputs: dict[str, paddle.Tensor]):

Callers 1

create_predictorFunction · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…