(model_path, backend, trust_remote_code=False, **kwargs)
| 15 | |
| 16 | |
def build_pipe(model_path, backend, trust_remote_code=False, **kwargs):
    """Build the inference pipeline used by interactive chat.

    Args:
        model_path: Model location; forwarded to ``pipeline`` and, when a
            chat template is requested, to ``get_chat_template``.
        backend: ``'turbomind'`` selects ``TurbomindEngineConfig``; any other
            value selects ``PytorchEngineConfig``.
        trust_remote_code: Forwarded unchanged to ``pipeline``.
        **kwargs: Extra options. Entries whose key names an attribute of the
            selected engine-config class are copied onto the config (for the
            non-turbomind backend ``device`` is copied as ``device_type``).
            ``enable_prefix_caching``, ``adapters`` and ``chat_template`` are
            additionally inspected here. The full mapping is also forwarded
            to ``pipeline``.

    Returns:
        The object returned by ``pipeline``.

    Raises:
        SystemExit: With code ``-1`` when ``enable_prefix_caching`` is truthy.
    """
    if kwargs.get('enable_prefix_caching', False):
        # Use sys.exit rather than the bare ``exit`` helper: the latter is
        # injected by the ``site`` module for interactive sessions and is not
        # guaranteed to exist (e.g. ``python -S`` or frozen builds).
        import sys
        print('interactive chat cannot be used when prefix caching is enabled')
        sys.exit(-1)

    # NOTE(review): hasattr() below is evaluated on the config *class*, so a
    # key that only exists on instances would be skipped -- confirm intended.
    if backend == 'turbomind':
        engine_config = TurbomindEngineConfig()
        for key, value in kwargs.items():
            if hasattr(TurbomindEngineConfig, key):
                setattr(engine_config, key, value)
    else:
        engine_config = PytorchEngineConfig()
        for key, value in kwargs.items():
            # The caller-facing option is 'device'; the config attribute is
            # 'device_type'.
            key = 'device_type' if key == 'device' else key
            if hasattr(PytorchEngineConfig, key):
                setattr(engine_config, key, value)
        if kwargs.get('adapters', None):
            from .utils import get_lora_adapters
            # Replace the raw value copied by the loop above with the
            # result of get_lora_adapters.
            engine_config.adapters = get_lora_adapters(kwargs['adapters'])
    # disable metrics to avoid installing prometheus_client, which is not
    # needed in interactive chat
    engine_config.enable_metrics = False

    # set chat template config
    chat_template = kwargs.get('chat_template', None)
    chat_template_config = None
    if chat_template:
        from .utils import get_chat_template
        chat_template_config = get_chat_template(chat_template, model_path)

    return pipeline(model_path,
                    backend_config=engine_config,
                    chat_template_config=chat_template_config,
                    log_level='ERROR',
                    trust_remote_code=trust_remote_code,
                    **kwargs)
| 54 | |
| 55 | |
| 56 | def build_gen_config(**kwargs): |
no test coverage detected