Get the generate_stream function for inference.
Signature: `get_generate_stream_function(model: torch.nn.Module, model_path: str)`
| 382 | |
| 383 | |
| 384 | def get_generate_stream_function(model: torch.nn.Module, model_path: str): |
| 385 | """Get the generate_stream function for inference.""" |
| 386 | from fastchat.serve.inference import generate_stream |
| 387 | |
| 388 | model_type = str(type(model)).lower() |
| 389 | is_peft = "peft" in model_type |
| 390 | is_chatglm = "chatglm" in model_type |
| 391 | is_falcon = "rwforcausallm" in model_type |
| 392 | is_codet5p = "codet5p" in model_type |
| 393 | is_exllama = "exllama" in model_type |
| 394 | is_xft = "xft" in model_type |
| 395 | is_yuan = "yuan" in model_type |
| 396 | |
| 397 | if is_chatglm: |
| 398 | return generate_stream_chatglm |
| 399 | elif is_falcon: |
| 400 | return generate_stream_falcon |
| 401 | elif is_codet5p: |
| 402 | return generate_stream_codet5p |
| 403 | elif is_exllama: |
| 404 | return generate_stream_exllama |
| 405 | elif is_xft: |
| 406 | return generate_stream_xft |
| 407 | elif is_yuan: |
| 408 | return generate_stream_yuan2 |
| 409 | |
| 410 | elif peft_share_base_weights and is_peft: |
| 411 | # Return a curried stream function that loads the right adapter |
| 412 | # according to the model_name available in this context. This ensures |
| 413 | # the right weights are available. |
| 414 | @torch.inference_mode() |
| 415 | def generate_stream_peft( |
| 416 | model, |
| 417 | tokenizer, |
| 418 | params: Dict, |
| 419 | device: str, |
| 420 | context_len: int, |
| 421 | stream_interval: int = 2, |
| 422 | judge_sent_end: bool = False, |
| 423 | ): |
| 424 | model.set_adapter(model_path) |
| 425 | base_model_type = str(type(model.base_model.model)) |
| 426 | is_chatglm = "chatglm" in base_model_type |
| 427 | is_falcon = "rwforcausallm" in base_model_type |
| 428 | is_codet5p = "codet5p" in base_model_type |
| 429 | is_exllama = "exllama" in base_model_type |
| 430 | is_xft = "xft" in base_model_type |
| 431 | is_yuan = "yuan" in base_model_type |
| 432 | |
| 433 | generate_stream_function = generate_stream |
| 434 | if is_chatglm: |
| 435 | generate_stream_function = generate_stream_chatglm |
| 436 | elif is_falcon: |
| 437 | generate_stream_function = generate_stream_falcon |
| 438 | elif is_codet5p: |
| 439 | generate_stream_function = generate_stream_codet5p |
| 440 | elif is_exllama: |
| 441 | generate_stream_function = generate_stream_exllama |