(
self, config: PredictorArgument, model: PretrainedModel = None, tokenizer: PretrainedTokenizer = None
)
| 265 | |
| 266 | class DygraphPredictor(BasePredictor): |
def __init__(
    self, config: PredictorArgument, model: PretrainedModel = None, tokenizer: PretrainedTokenizer = None
):
    """Set up a dygraph predictor, loading and wrapping the model as needed.

    Args:
        config: Predictor settings. This method reads ``lora_path``,
            ``prefix_path``, ``dtype``, ``model_name_or_path`` and
            ``use_flash_attention`` from it.
        model: Optional already-constructed model. When ``None``, a model is
            loaded with ``AutoModelForCausalLM.from_pretrained``.
        tokenizer: Forwarded unchanged to the base-class initializer.

    Raises:
        ValueError: If no dtype can be determined, i.e. ``config.lora_path``,
            ``config.prefix_path`` and ``config.dtype`` are all ``None``.
    """
    super().__init__(config, tokenizer)
    self.model = model

    # Resolve the dtype in priority order: LoRA config first, then prefix
    # config, then the explicit ``config.dtype``.
    if config.lora_path is not None:
        lora_config = LoRAConfig.from_pretrained(config.lora_path)
        dtype = lora_config.dtype
        # Set before the config is handed to LoRAModel.from_pretrained below.
        lora_config.merge_weights = True
    elif config.prefix_path is not None:
        prefix_config = PrefixConfig.from_pretrained(config.prefix_path)
        dtype = prefix_config.dtype
    elif config.dtype is not None:
        dtype = config.dtype
    else:
        # NOTE(review): this is raised even when ``model`` was supplied, in
        # which case ``dtype`` is never used below -- confirm this is intended.
        raise ValueError("Please specify the model dtype.")

    if self.model is None:
        # tensor_parallel_degree / tensor_parallel_rank are read from self;
        # presumably set by the base-class initializer -- verify there.
        self.model = AutoModelForCausalLM.from_pretrained(
            config.model_name_or_path,
            use_flash_attention=config.use_flash_attention,
            dtype=dtype,
            tensor_parallel_degree=self.tensor_parallel_degree,
            tensor_parallel_rank=self.tensor_parallel_rank,
        )

    # The two adapter checks are independent ``if`` statements: when both
    # paths are set, the LoRA wrapper is applied first and the prefix wrapper
    # is applied on top of it (the dtype then comes from the LoRA config).
    if config.lora_path is not None:
        self.model = LoRAModel.from_pretrained(
            model=self.model, lora_path=config.lora_path, lora_config=lora_config
        )
    if config.prefix_path is not None:
        prefix_tuning_params = get_prefix_tuning_params(self.model)
        self.model = PrefixModelForCausalLM.from_pretrained(
            model=self.model,
            prefix_path=config.prefix_path,
            postprocess_past_key_value=prefix_tuning_params["postprocess_past_key_value"],
        )

    # Call eval() on the final (possibly wrapped) model.
    self.model.eval()
| 305 | |
| 306 | @paddle.no_grad() |
| 307 | def _infer(self, inputs: dict[str, paddle.Tensor]): |
nothing calls this directly
no test coverage detected