MCPcopy
hub / github.com/InternLM/lmdeploy / smooth_quant

Function smooth_quant

lmdeploy/lite/apis/smooth_quant.py:18–142  ·  view source on GitHub ↗
(model: str,
                 work_dir: str = './work_dir',
                 calib_dataset: str = 'wikitext2',
                 calib_samples: int = 128,
                 calib_seqlen: int = 2048,
                 search_scale: bool = False,
                 batch_size: int = 1,
                 w_bits: int = 8,
                 dtype: Literal['float16', 'bfloat16', 'auto'] = 'auto',
                 device: str = 'cuda',
                 quant_dtype: Literal['int8', 'fp8', 'float8_e4m3fn', 'float8_e5m2'] = 'int8',
                 revision: str = None,
                 download_dir: str = None,
                 trust_remote_code: bool = False)

Source from the content-addressed store, hash-verified

16
17
18def smooth_quant(model: str,
19 work_dir: str = './work_dir',
20 calib_dataset: str = 'wikitext2',
21 calib_samples: int = 128,
22 calib_seqlen: int = 2048,
23 search_scale: bool = False,
24 batch_size: int = 1,
25 w_bits: int = 8,
26 dtype: Literal['float16', 'bfloat16', 'auto'] = 'auto',
27 device: str = 'cuda',
28 quant_dtype: Literal['int8', 'fp8', 'float8_e4m3fn', 'float8_e5m2'] = 'int8',
29 revision: str = None,
30 download_dir: str = None,
31 trust_remote_code: bool = False):
32 try_import_deeplink(device)
33 if quant_dtype == 'fp8':
34 quant_dtype = 'float8_e4m3fn'
35
36 quant_dtype = getattr(torch, quant_dtype, torch.int8)
37 if quant_dtype.is_floating_point:
38 q_dtype_info = torch.finfo(quant_dtype)
39 else:
40 q_dtype_info = torch.iinfo(quant_dtype)
41
42 assert q_dtype_info.bits == w_bits
43 if not osp.exists(model):
44 print(f'can\'t find model from local_path {model}, '
45 'try to download from remote')
46 from lmdeploy.utils import get_model
47 model = get_model(model, revision=revision, download_dir=download_dir)
48 model_path = model
49 arch, vl_model, model, tokenizer, work_dir = calibrate(model,
50 calib_dataset,
51 calib_samples,
52 calib_seqlen,
53 work_dir,
54 device,
55 w_bits=w_bits,
56 w_group_size=-1,
57 search_scale=search_scale,
58 dtype=dtype,
59 batch_size=batch_size,
60 trust_remote_code=trust_remote_code)
61
62 # calibrate function exports the calibration statistics
63 # (inputs, outputs, keys and values) to `work_dir`.
64 inp_stats = torch.load(work_dir / 'inputs_stats.pth', weights_only=True)
65 act_scales = inp_stats['absmax']
66
67 model_type = type(model).__name__
68 if model_type not in LAYER_TYPE_MAP or model_type not in NORM_TYPE_MAP:
69 raise RuntimeError(f'Currently, quantification and calibration of {model_type} are '
70 f'not supported. The supported model types are '
71 f"{', '.join(LAYER_TYPE_MAP.keys())}.")
72
73 if model_type == 'QWenLMHeadModel':
74 try:
75 import flash_attn # noqa: F401

Callers 1

smooth_quant · Method · 0.90

Calls 15

try_import_deeplink · Function · 0.90
get_model · Function · 0.90
calibrate · Function · 0.90
collect_target_modules · Function · 0.90
awq_layers · Function · 0.90
smooth_layers · Function · 0.90
skipped_module · Function · 0.90
save_vl_model · Function · 0.85
join · Method · 0.80
items · Method · 0.80
extend · Method · 0.80
load · Method · 0.45

Tested by

no test coverage detected