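Establish a series of checkpoints on your network. A checkpoint is a data structure that helps us compare quantized results with fp32 results. Args: graph (BaseGraph): graph whose quantable operations are validated and temporarily dequantized. executor (BaseGraphExecutor): executor used to run the forward passes. dataloader (Iterable): batches fed to the executor. collate_fn (Callable): optional function applied to each batch before the forward pass (skipped when None). Raises: PermissionError: if any quantization config is already in a BAKED or PASSIVE_BAKED state.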
(
self, graph: BaseGraph, executor: BaseGraphExecutor,
dataloader: Iterable, collate_fn: Callable)
| 44 | |
@empty_ppq_cache
def initialize_checkpoints(
    self, graph: BaseGraph, executor: BaseGraphExecutor,
    dataloader: Iterable, collate_fn: Callable):
    """
    Establish a checkpoint for each interested output variable of the network.

    A checkpoint (FinetuneCheckPoint) holds reference outputs so that quantized
    results can later be compared against fp32 results. Reference outputs are
    collected with the graph temporarily dequantized; the quantization state is
    restored afterwards, even if collection fails. Finally ``self.check`` is run
    once over the same data with verbose output suppressed.

    Args:
        graph (BaseGraph): Graph whose quantable operations are validated and
            temporarily dequantized. If no interested outputs were configured,
            every name in ``graph.outputs`` is used.
        executor (BaseGraphExecutor): Executor used to run the forward passes.
        dataloader (Iterable): Batches fed to ``executor.forward``.
            NOTE(review): it is iterated twice (here and inside ``check``), so a
            one-shot iterator/generator would be exhausted on the second pass —
            confirm callers pass a re-iterable object.
        collate_fn (Callable): Optional function applied to each batch before
            the forward pass; skipped when None.

    Raises:
        PermissionError: If any quantization config of a QuantableOperation is
            in the BAKED or PASSIVE_BAKED state; checkpoints cannot be built
            from baked values.
    """
    baked_states = {QuantizationStates.BAKED, QuantizationStates.PASSIVE_BAKED}
    quantable_ops = [
        op for op in graph.operations.values()
        if isinstance(op, QuantableOperation)
    ]

    # Validate before touching any state, so a rejected call leaves self unchanged.
    for op in quantable_ops:
        for cfg, var in op.config_with_variable:
            if cfg.state in baked_states:
                raise PermissionError('Can not initialize checkpoints when weight value is baked. '
                                      f'Variable {var.name} has a baked value.')

    if not self._interested_outputs:
        self._interested_outputs = list(graph.outputs)

    for name in self._interested_outputs:
        self._checkpoints[name] = FinetuneCheckPoint(variable=name)

    # Dequantize the graph so the collected outputs serve as fp32 references.
    # Track exactly which ops were dequantized so that, whatever fails, only
    # those ops get their quantization state restored.
    dequantized_ops = []
    try:
        for op in quantable_ops:
            op.dequantize()
            dequantized_ops.append(op)

        for data in tqdm(dataloader, desc='Collecting References'):
            if collate_fn is not None:
                data = collate_fn(data)
            outputs = executor.forward(inputs=data, output_names=self._interested_outputs)
            for name, output in zip(self._interested_outputs, outputs):
                ckpt = self._checkpoints[name]
                assert isinstance(ckpt, FinetuneCheckPoint)
                ckpt.push(tensor=output, is_reference=True)
    finally:
        for op in dequantized_ops:
            op.restore_quantize_state()

    # Run one silent check() pass over the same data (the original comment
    # describes this as "update state"); the caller's verbosity is restored
    # even if check() raises.
    verbose, self._verbose = self._verbose, False
    try:
        self.check(executor=executor, dataloader=dataloader, collate_fn=collate_fn)
    finally:
        self._verbose = verbose
| 96 | |
| 97 | def check(self, executor: BaseGraphExecutor, |
| 98 | dataloader: Iterable, collate_fn: Callable): |
no test coverage detected