(model_path, exllama_config: ExllamaConfig)
| 17 | |
| 18 | |
def load_exllama_model(model_path, exllama_config: ExllamaConfig):
    """Build an ExLlamaV2 model wrapper and its tokenizer for ``model_path``.

    Args:
        model_path: Assigned to ``ExLlamaV2Config.model_dir`` before
            ``prepare()`` is called.
        exllama_config: Supplies ``max_seq_len``, ``cache_8bit`` and the
            optional comma-separated ``gpu_split`` string.

    Returns:
        A ``(model, tokenizer)`` tuple: ``model`` is an ``ExllamaModel``
        holding the loaded ``ExLlamaV2`` instance and its cache, and
        ``tokenizer`` is an ``ExLlamaV2Tokenizer``.

    Raises:
        SystemExit: With status ``-1`` when ``exllamav2`` cannot be imported.
    """
    # exllamav2 is imported lazily, inside the function, so a missing package
    # is only reported when this loader is actually called.
    try:
        from exllamav2 import (
            ExLlamaV2Config,
            ExLlamaV2Tokenizer,
            ExLlamaV2,
            ExLlamaV2Cache,
            ExLlamaV2Cache_8bit,
        )
    except ImportError as e:
        print(f"Error: Failed to load Exllamav2. {e}")
        sys.exit(-1)

    cfg = ExLlamaV2Config()
    cfg.model_dir = model_path
    cfg.prepare()
    # These overrides are applied after prepare() -- presumably so they are
    # not replaced by values prepare() derives from the model directory;
    # confirm against the exllamav2 documentation.
    cfg.max_seq_len = exllama_config.max_seq_len
    cfg.cache_8bit = exllama_config.cache_8bit

    backend = ExLlamaV2(cfg)
    tokenizer = ExLlamaV2Tokenizer(cfg)

    # A falsy gpu_split (None or "") means load() receives None.
    allocations = (
        list(map(float, exllama_config.gpu_split.split(",")))
        if exllama_config.gpu_split
        else None
    )
    backend.load(allocations)

    # The cache is created only after the weights have been loaded.
    if cfg.cache_8bit:
        cache = ExLlamaV2Cache_8bit(backend)
    else:
        cache = ExLlamaV2Cache(backend)

    return ExllamaModel(exllama_model=backend, exllama_cache=cache), tokenizer
no test coverage detected
searching dependent graphs…