(model_path, xft_config: XftConfig)
| 22 | |
| 23 | |
def load_xft_model(model_path, xft_config: XftConfig):
    """Load an xFasterTransformer model together with its Hugging Face tokenizer.

    Args:
        model_path: Location passed to both ``AutoTokenizer.from_pretrained``
            and ``xfastertransformer.AutoModel.from_pretrained``.
        xft_config: Loader configuration. ``xft_config.data_type`` selects the
            model dtype; when it is ``None`` or ``""`` the default
            ``"bf16_fp16"`` is used. The whole object is also handed to
            ``XftModel``.

    Returns:
        A ``(model, tokenizer)`` tuple: ``model`` is an ``XftModel`` wrapping
        the loaded xFasterTransformer model, ``tokenizer`` is the
        ``AutoTokenizer`` instance (slow tokenizer, left padding).

    Raises:
        SystemExit: via ``sys.exit(-1)`` when ``xfastertransformer`` or
            ``transformers`` cannot be imported.

    Note:
        When the wrapped model reports ``rank > 0`` this function never
        returns; it loops forever calling ``generate()``. Presumably those
        ranks are worker processes of a multi-rank run that is driven by
        rank 0 -- confirm against the xFasterTransformer deployment docs.
    """
    try:
        import xfastertransformer
        from transformers import AutoTokenizer
    except ImportError as e:
        # Diagnostics go to stderr so they are not mixed into regular stdout
        # output (e.g. when the process output is piped or parsed).
        print(f"Error: Failed to load xFasterTransformer. {e}", file=sys.stderr)
        sys.exit(-1)

    # Fall back to the default dtype when data_type is unset (None) or empty.
    data_type = xft_config.data_type or "bf16_fp16"

    # NOTE(review): trust_remote_code=True allows code shipped with the model
    # repository to run; only point model_path at trusted sources.
    tokenizer = AutoTokenizer.from_pretrained(
        model_path, use_fast=False, padding_side="left", trust_remote_code=True
    )
    xft_model = xfastertransformer.AutoModel.from_pretrained(
        model_path, dtype=data_type
    )
    model = XftModel(xft_model=xft_model, xft_config=xft_config)

    # Non-zero ranks never return from here: they repeatedly call generate().
    if model.model.rank > 0:
        while True:
            model.model.generate()
    return model, tokenizer
no test coverage detected
searching dependent graphs…