| 390 | op.restore_quantize_state(expire_device=None) |
| 391 | |
def run_pair(self,
             pair: List[Operation],
             inputs: List[torch.Tensor],
             hooks: 'Optional[Dict[Operation, CalibrationHook]]' = None) -> List[torch.Tensor]:
    """Execute ``pair`` in order, fake-quantizing around quantable operations.

    The running tensor list starts as ``inputs``. Before each operation, the
    values of that operation's parameters are appended to it.

    For a ``QuantableOperation``:
      - every input is passed through ``self.quant_func`` with the matching
        entry of ``op.config.input_quantization_config``;
      - the forward function runs on the quantized inputs;
      - every output is passed through ``self.quant_func`` with the matching
        entry of ``op.config.output_quantization_config``;
      - the quantized outputs become the leading inputs of the next operation.

    Non-quantable operations run on, and forward, the raw tensors.

    Args:
        pair: operations to execute sequentially. The outputs of one become
            the leading inputs of the next.
        inputs: input tensors of the first operation, excluding its parameters
            (those are appended here).
        hooks: optional mapping from operation to a CalibrationHook. Its
            ``pre_forward_hook(raw_inputs, quant_inputs, input_configs)`` and
            ``post_forward_hook(raw_outputs, quant_outputs, output_configs)``
            are invoked, but only for quantable operations.

    Returns:
        The output tensors of the last operation, as a list. They are the
        quantized outputs if that operation is quantable.

    Raises:
        AssertionError: if a quantable operation's input count (activations
            plus parameters) differs from its number of input quantization
            configs.
    """
    # Avoid a shared mutable default; an empty mapping means "no hooks".
    if hooks is None:
        hooks = {}

    # Copy into a list so later `+` concatenation works even if the caller
    # passed a tuple.
    inputs = list(inputs)

    for op in pair:
        # Parameters are appended after the activations. NOTE(review): this
        # presumably matches the order of input_quantization_config — confirm.
        inputs = inputs + [param.value for param in op.parameters]
        quantable = isinstance(op, QuantableOperation)
        hook = hooks.get(op) if quantable else None

        if quantable:
            input_configs = list(op.config.input_quantization_config)
            assert len(inputs) == len(input_configs), (
                f'run_pair: got {len(inputs)} inputs (including parameters) '
                f'but {len(input_configs)} input quantization configs.')
            inputs_quant = [self.quant_func(tensor, config)
                            for tensor, config in zip(inputs, input_configs)]
            if hook is not None:
                hook.pre_forward_hook(inputs, inputs_quant, input_configs)
        else:
            inputs_quant = inputs

        forward_fn = OPERATION_FORWARD_TABLE[op.platform][op.type]
        outputs = forward_fn(op, inputs_quant)
        # Normalise to a list: a bare tensor is wrapped, and a tuple is
        # converted. A tuple would otherwise break `inputs + [...]` on the next
        # iteration and violate the List return type.
        if isinstance(outputs, (list, tuple)):
            outputs = list(outputs)
        else:
            outputs = [outputs]

        if quantable:
            output_configs = list(op.config.output_quantization_config)
            outputs_quant = [self.quant_func(tensor, config)
                             for tensor, config in zip(outputs, output_configs)]
            if hook is not None:
                hook.post_forward_hook(outputs, outputs_quant, output_configs)
            inputs = outputs_quant
        else:
            inputs = outputs
    return inputs
| 422 | |
| 423 | # mse calculation for a list of tensors |
| 424 | def calculate_mse(self, fp_res: List[torch.Tensor], quant_res: List[torch.Tensor]) -> torch.Tensor: |