(self, config_path: str, graph: BaseGraph)
| 35 | |
| 36 | class CaffeExporter(GraphExporter): |
def export_quantization_config(self, config_path: str, graph: BaseGraph):
    """Write the quantization parameters of ``graph`` to a JSON file.

    The JSON document has two top-level entries:

    * ``'quant_info'`` -- maps a variable name to the rendered fields of the
      quantization config selected for that variable.
    * ``'op_info'`` -- maps an operation name to ``{'data_type': ...}`` for
      every operation whose platform ``convert_type`` maps to a non-None
      value.

    Args:
        config_path: Path of the JSON file to write (opened in ``'w'`` mode).
        graph: Graph whose operations and their quantization configs are
            inspected.
    """
    selected_configs, op_platform_recorder = {}, {}

    # Pass 1: pick one config per variable name.
    for op in graph.operations.values():
        if not isinstance(op, QuantableOperation):
            continue
        for cfg, variable in op.config_with_variable:
            if not cfg.can_export():
                continue

            # Patch from 2021.11.25: configs wider than 8 bits are left out
            # (the original note describes this as removing bias from the
            # exported configuration).
            if cfg.num_of_bits > 8:
                continue

            if cfg.state == QuantizationStates.FP32:
                continue

            # A later config simply overrides an earlier one for the same
            # variable. Original note: mixed precision quantization is not
            # supported for the CUDA backend yet, so all configs of one
            # variable are expected to be identical. The only exception is a
            # PASSIVE config, which never replaces an existing record.
            if cfg.state == QuantizationStates.PASSIVE and variable.name in selected_configs:
                continue
            selected_configs[variable.name] = cfg

    # Pass 2: render each selected config into JSON-friendly fields.
    quant_info = {}
    for var_name, cfg in selected_configs.items():
        assert isinstance(cfg, TensorQuantizationConfig)
        policy = cfg.policy
        tensorwise = policy.has_property(QuantizationProperty.PER_TENSOR)
        quant_info[var_name] = {
            'bit_width': cfg.num_of_bits,
            'per_channel': policy.has_property(QuantizationProperty.PER_CHANNEL),
            'quant_flag': True,
            'sym': policy.has_property(QuantizationProperty.SYMMETRICAL),
            'scale': convert_value(cfg.scale, tensorwise, DataType.FP32),
            'zero_point': convert_value(cfg.offset, tensorwise, DataType.INT32),
            # Real-valued range covered by the integer range [q_min, q_max].
            'tensor_min': convert_value(
                cfg.scale * (cfg.quant_min - cfg.offset), tensorwise, DataType.FP32),
            'tensor_max': convert_value(
                cfg.scale * (cfg.quant_max - cfg.offset), tensorwise, DataType.FP32),
            'q_min': cfg.quant_min,
            'q_max': cfg.quant_max,
            'hash': cfg.__hash__(),
            'dominator': cfg.dominated_by.__hash__(),
        }

    # Pass 3: record the data type of every operation that has one.
    for op in graph.operations.values():
        if convert_type(op.platform) is None:
            continue
        op_platform_recorder[op.name] = {'data_type': convert_type(op.platform)}

    with open(file=config_path, mode='w') as file:
        json.dump(
            {'quant_info': quant_info, 'op_info': op_platform_recorder},
            file,
            indent=4,
        )
| 91 | |
| 92 | def prepare_model(self, graph: BaseGraph, input_shapes: List[List[int]]): |
| 93 | # trace model for exporting. |
no test coverage detected