quantize data and return data dir
(llm_venv,
example_root,
model_dir,
dtype,
quantize_dir,
qformat="full_prec",
tp_size=1,
pp_size=1,
cp_size=1,
calib_size=512,
kv_cache_dtype=None,
**kwargs)
| 585 | |
| 586 | |
def quantize_data(llm_venv,
                  example_root,
                  model_dir,
                  dtype,
                  quantize_dir,
                  qformat="full_prec",
                  tp_size=1,
                  pp_size=1,
                  cp_size=1,
                  calib_size=512,
                  kv_cache_dtype=None,
                  **kwargs):
    """Quantize a model checkpoint and return the quantized data dir.

    Builds the command line for the ``quantize.py`` script. When ``llm_venv``
    is truthy the command is executed through ``venv_check_call`` (skipped if
    the output directory already exists); otherwise nothing is executed and
    the command is returned to the caller.

    Args:
        llm_venv: Environment handle forwarded to ``venv_check_call``. When
            falsy, the command is returned instead of being run.
        example_root: Example directory used to locate ``quantize.py``.
        model_dir: Source checkpoint directory; its basename becomes part of
            the output path.
        dtype: Forwarded as ``--dtype``; also a path component.
        quantize_dir: Root directory under which outputs are placed.
        qformat: Forwarded as ``--qformat``; also a path component.
        tp_size: Forwarded as ``--tp_size``; part of the parallelism dir name.
        pp_size: Forwarded as ``--pp_size``; part of the parallelism dir name.
        cp_size: Forwarded as ``--cp_size``; appended to the parallelism dir
            name when it is not 1.
        calib_size: Forwarded as ``--calib_size``.
        kv_cache_dtype: When truthy, forwarded as ``--kv_cache_dtype`` and used
            as the last path component; otherwise ``"no_kv_cache"`` is used.
        **kwargs: Extra script flags. ``timeout`` is consumed and passed to
            ``venv_check_call`` instead. A ``True`` bool becomes a bare
            ``--key`` flag, a ``False`` bool is omitted, and any other value
            becomes ``--key value``.

    Returns:
        ``output_dir`` when ``llm_venv`` is truthy; otherwise the tuple
        ``(quantize_cmd, output_dir)``.
    """
    model_name = os.path.basename(model_dir)

    # cp_size is part of the cache key: without it, a run with a different
    # context-parallel size would hit the `exists(output_dir)` check below and
    # silently reuse another configuration's output. cp_size == 1 keeps the
    # historical "tp{N}pp{M}" directory name so existing paths are unchanged.
    parallel_dir = f"tp{tp_size}pp{pp_size}"
    if cp_size != 1:
        parallel_dir += f"cp{cp_size}"

    # NOTE(review): extra **kwargs flags are not part of this path, so runs
    # that differ only in those flags still share an output directory.
    output_dir = os.path.join(quantize_dir, model_name, dtype, qformat,
                              parallel_dir, kv_cache_dtype or "no_kv_cache")

    # NOTE(review): substring test on the whole path -- any example_root that
    # merely contains "core" selects the deeper relative path. Confirm this
    # matches the intended directory layout.
    if "core" in example_root:
        quantize_script = f"{example_root}/../../../quantization/quantize.py"
    else:
        quantize_script = f"{example_root}/../quantization/quantize.py"

    quantize_cmd = [
        quantize_script,
        f"--model_dir={model_dir}",
        f"--dtype={dtype}",
        f"--qformat={qformat}",
        f"--output_dir={output_dir}",
        f"--tp_size={tp_size}",
        f"--pp_size={pp_size}",
        f"--cp_size={cp_size}",
        f"--calib_size={calib_size}",
    ]

    if kv_cache_dtype:
        quantize_cmd.append(f"--kv_cache_dtype={kv_cache_dtype}")

    # `timeout` controls the subprocess call, not the script, so strip it out
    # before turning the remaining kwargs into CLI flags.
    timeout = kwargs.pop('timeout', None)

    for key, value in kwargs.items():
        if isinstance(value, bool):
            # Store-true style flags: only emitted when enabled.
            if value:
                quantize_cmd.append(f"--{key}")
        else:
            quantize_cmd.extend([f"--{key}", f"{value}"])

    if llm_venv:
        if not exists(output_dir):
            venv_check_call(llm_venv, quantize_cmd, timeout=timeout)
        return output_dir
    else:
        return quantize_cmd, output_dir
| 638 | |
| 639 | |
| 640 | def find_tensorrt(ld_library_path): |