| 106 | self._override = override |
| 107 | |
def calibrate(self, desc: str, dataloader: Iterable, executor: BaseGraphExecutor,
              hooks: Dict[str, RuntimeHook], output_names: List[str] = None):
    """Run calibration forward passes until ``self._calib_steps`` batches are consumed.

    ``dataloader`` is iterated repeatedly (restarted as many times as needed).
    Each batch goes through ``self._collate_fn`` (when set) and then through
    ``executor.forward`` with ``hooks``. Progress is shown on a tqdm bar.

    Args:
        desc: Label displayed on the progress bar.
        dataloader: Source of calibration batches. It must support ``len()``,
            because the number of passes over it is derived from that length.
        executor: Graph executor whose ``forward`` is called once per batch.
        hooks: Runtime hooks passed to every ``executor.forward`` call.
        output_names: Passed unchanged to ``executor.forward``.

    Raises:
        ValueError: If ``dataloader`` is empty while a positive number of
            calibration steps is requested. The original code failed here
            with a bare ZeroDivisionError.
    """
    total_steps = self._calib_steps
    num_batches = len(dataloader)
    if num_batches == 0:
        if total_steps > 0:
            raise ValueError(
                f'Calibration dataloader is empty, cannot run {total_steps} calibration steps.')
        num_epochs = 0
    else:
        # Number of full/partial passes needed to reach total_steps batches.
        num_epochs = ceil(total_steps / num_batches)

    calib_step = 0
    with tqdm(total=total_steps, desc=desc) as progressing_bar:
        for _ in range(num_epochs):
            for data in dataloader:
                if self._collate_fn is not None:
                    data = self._collate_fn(data)
                executor.forward(inputs=data, hooks=hooks,
                                 output_names=output_names)
                progressing_bar.update()
                calib_step += 1
                # Return rather than break so that no further epoch can start.
                # A bare break only exits the inner loop. That matters when the
                # loader yields more items than len() reports.
                if calib_step >= total_steps:
                    return
| 123 | @ empty_ppq_cache |
| 124 | def optimize( |