(self, config_path: str, graph: BaseGraph)
| 18 | |
class PPLBackendExporter(OnnxExporter):
    """OnnxExporter subclass that can dump quantization settings as JSON."""

    def export_quantization_config(self, config_path: str, graph: BaseGraph):
        """Write per-variable quantization info and per-op platform info to JSON.

        Args:
            config_path: Destination file path; opened in ``'w'`` mode.
            graph: Graph whose ``operations`` mapping is scanned.

        The emitted JSON object has two top-level keys: ``quant_info``
        (keyed by variable name) and ``op_info`` (keyed by operation name).
        """
        # Pass 1: pick at most one quantization config per variable name.
        chosen_configs = {}
        for candidate in graph.operations.values():
            if not isinstance(candidate, QuantableOperation):
                continue

            for config, var in candidate.config_with_variable:
                if not config.can_export():
                    continue

                # PATCH 2021.11.25
                # REMOVE BIAS FROM CONFIGURATION
                # (any config wider than 8 bits is left out of the export)
                if config.num_of_bits > 8:
                    continue

                if config.state == QuantizationStates.FP32:
                    continue

                # Simply overriding the recorded entry is acceptable here:
                # mixed precision quantization is not supported for the CUDA
                # backend now, so all configurations of one variable should be
                # identical to each other. A PASSIVE config, however, never
                # replaces an entry that has already been recorded.
                is_passive = config.state == QuantizationStates.PASSIVE
                if is_passive and var.name in chosen_configs:
                    continue
                chosen_configs[var.name] = config

        # Pass 2: render every chosen config into a JSON-friendly dict.
        quant_info = {}
        for var_name, config in chosen_configs.items():
            assert isinstance(config, TensorQuantizationConfig)
            is_per_tensor = config.policy.has_property(QuantizationProperty.PER_TENSOR)
            quant_info[var_name] = {
                'bit_width'  : config.num_of_bits,
                'per_channel': config.policy.has_property(QuantizationProperty.PER_CHANNEL),
                'quant_flag' : True,
                'sym'        : config.policy.has_property(QuantizationProperty.SYMMETRICAL),
                'scale'      : convert_value(config.scale, is_per_tensor, DataType.FP32),
                'zero_point' : convert_value(config.offset, is_per_tensor, DataType.INT32),
                # Value range obtained by mapping quant_min / quant_max back
                # through scale and offset.
                'tensor_min' : convert_value(
                    config.scale * (config.quant_min - config.offset),
                    is_per_tensor, DataType.FP32),
                'tensor_max' : convert_value(
                    config.scale * (config.quant_max - config.offset),
                    is_per_tensor, DataType.FP32),
                'q_min'      : config.quant_min,
                'q_max'      : config.quant_max,
                'hash'       : config.__hash__(),
                'dominator'  : config.dominated_by.__hash__(),
            }

        # Pass 3: record a data type for every op whose platform converts to one.
        op_info = {
            op.name: {'data_type': convert_type(op.platform)}
            for op in graph.operations.values()
            if convert_type(op.platform) is not None
        }

        payload = {'quant_info': quant_info, 'op_info': op_info}
        with open(file=config_path, mode='w') as handle:
            json.dump(payload, handle, indent=4)
no test coverage detected