MCPcopy
hub / github.com/OpenPPL/ppq / quantize_torch_model

Function quantize_torch_model

ppq/api/interface.py:279–346  ·  view source on GitHub ↗

量化一个 Pytorch 原生的模型 输入一个 torch.nn.Module 返回一个量化后的 PPQ.IR.BaseGraph. quantize a pytorch model, input pytorch model and return quantized ppq IR graph Args: model (torch.nn.Module): 被量化的 torch 模型(torch.nn.Module) the pytorch model calib_dataloader (DataLoader): 校准数据集 calibr

(
    model: torch.nn.Module,
    calib_dataloader: DataLoader,
    calib_steps: int,
    input_shape: List[int],
    platform: TargetPlatform,
    input_dtype: torch.dtype = torch.float,
    setting: QuantizationSetting = None,
    collate_fn: Callable = None,
    inputs: List[Any] = None,
    do_quantize: bool = True,
    onnx_export_file: str = 'onnx.model',
    device: str = 'cuda',
    verbose: int = 0,
    )

Source from the content-addressed store, hash-verified

277
278@ empty_ppq_cache
279def quantize_torch_model(
280 model: torch.nn.Module,
281 calib_dataloader: DataLoader,
282 calib_steps: int,
283 input_shape: List[int],
284 platform: TargetPlatform,
285 input_dtype: torch.dtype = torch.float,
286 setting: QuantizationSetting = None,
287 collate_fn: Callable = None,
288 inputs: List[Any] = None,
289 do_quantize: bool = True,
290 onnx_export_file: str = 'onnx.model',
291 device: str = 'cuda',
292 verbose: int = 0,
293 ) -> BaseGraph:
294 """量化一个 Pytorch 原生的模型 输入一个 torch.nn.Module 返回一个量化后的 PPQ.IR.BaseGraph.
295
296 quantize a pytorch model, input pytorch model and return quantized ppq IR graph
297 Args:
298 model (torch.nn.Module): 被量化的 torch 模型(torch.nn.Module) the pytorch model
299
300 calib_dataloader (DataLoader): 校准数据集 calibration dataloader
301
302 calib_steps (int): 校准步数 calibration steps
303
304 collate_fn (Callable): 校准数据的预处理函数 batch collate func for preprocessing
305
306 input_shape (List[int]): 模型输入尺寸,用于执行 jit.trace,对于动态尺寸的模型,输入一个模型可接受的尺寸即可。
307 如果模型存在多个输入,则需要使用 inputs 变量进行传参,此项设置为 None
308 a list of ints indicating size of input, for multiple inputs, please use
309 keyword arg inputs for direct parameter passing and this should be set to None
310
311 input_dtype (torch.dtype): 模型输入数据类型,如果模型存在多个输入,则需要使用 inputs 变量进行传参,此项设置为 None
312 the torch datatype of input, for multiple inputs, please use keyword arg inputs
313 for direct parameter passing and this should be set to None
314
315 setting (OptimSetting): 量化配置信息,用于配置量化的各项参数,设置为 None 时加载默认参数。
316 Quantization setting, default setting will be used when set None
317
318 inputs (List[Any], optional): 对于存在多个输入的模型,在Inputs中直接指定一个输入List,从而完成模型的tracing。
319 for multiple inputs, please give the specified inputs directly in the form of
320 a list of arrays
321
322 do_quantize (Bool, optional): 是否执行量化 whether to quantize the model, defaults to True, defaults to True.
323
324 platform (TargetPlatform, optional): 量化的目标平台 target backend platform, defaults to TargetPlatform.DSP_INT8.
325
326 device (str, optional): 量化过程的执行设备 execution device, defaults to 'cuda'.
327
328 verbose (int, optional): 是否打印详细信息 whether to print details, defaults to 0.
329
330 Raises:
331 ValueError: 给定平台不可量化 the given platform doesn't support quantization
332 KeyError: 给定平台不被支持 the given platform is not supported yet
333
334 Returns:
335 BaseGraph: 量化后的IR,包含了后端量化所需的全部信息
336 The quantized IR, containing all information needed for backend execution

Callers 15

quantize_dsp.pyFile · 0.90
execute.pyFile · 0.90
finetune.pyFile · 0.90
optimization.pyFile · 0.90
targetPlatform.pyFile · 0.90
fusion.pyFile · 0.90
analyse.pyFile · 0.90
dequantize.pyFile · 0.90

Calls 2

dump_torch_to_onnxFunction · 0.85
quantize_onnx_modelFunction · 0.85

Tested by

no test coverage detected