Generate the common CLI parameter string from a run config dict.
(run_config: dict[str, Any])
| 557 | |
| 558 | |
def get_cli_common_param(run_config: dict[str, Any]) -> str:
    """Generate the common CLI parameter string from a run config dict.

    Args:
        run_config: Run configuration. Keys read here: ``backend``, ``model``,
            ``communicator``, ``quant_policy``, ``extra_params`` and
            ``parallel_config``. Missing or ``None`` values are omitted from
            the output instead of being rendered as the literal ``None``.

    Returns:
        Space-separated CLI arguments, always ending with
        ``--trust-remote-code``.
    """
    backend = run_config.get('backend')
    # ``or ''`` keeps ``.lower()`` below from crashing when no model is given.
    model = run_config.get('model') or ''
    communicator = run_config.get('communicator')
    quant_policy = run_config.get('quant_policy')
    # ``or {}`` also covers keys that are present but explicitly set to None.
    extra_params = run_config.get('extra_params') or {}
    parallel_config = run_config.get('parallel_config') or {}

    cli_params = []
    if backend is not None:
        cli_params.append(f'--backend {backend}')
    if communicator is not None:
        cli_params.append(f'--communicator {communicator}')

    # A quant policy of 0, or no policy at all, means the flag is omitted.
    if quant_policy not in (None, 0):
        cli_params.append(f'--quant-policy {quant_policy}')

    # Infer the quantized weight format from the model name. GPTQ is checked
    # first so that a name matching both patterns (e.g. "...-w4-gptq") yields
    # exactly one --model-format flag instead of two conflicting ones.
    model_lower = model.lower()
    if 'gptq' in model_lower:
        cli_params.append('--model-format gptq')
    elif any(tag in model_lower for tag in ('w4', '4bits', 'awq')):
        cli_params.append('--model-format awq')

    # Parallelism degrees. Only values > 1 are passed through, in a fixed
    # order (dp, ep, cp, tp).
    for para_key in ('dp', 'ep', 'cp', 'tp'):
        para_value = parallel_config.get(para_key)
        if para_value is not None and para_value > 1:
            cli_params.append(f'--{para_key} {para_value}')

    # Extra params are rendered by the sibling helper ``get_cli_str``.
    if extra_params:
        cli_params.append(get_cli_str(extra_params))
    cli_params.append('--trust-remote-code')

    # Drop empty fragments so an empty helper result cannot leave a double space.
    return ' '.join(param for param in cli_params if param).strip()
| 594 | |
| 595 | |
| 596 | def get_cli_str(config: dict[str, Any]) -> str: |
no test coverage detected