Check quantization error with a given dataloader with current checkpoints. Return whether quantization error is lower than before. Args: executor (BaseGraphExecutor): [description] dataloader (Iterable): [description] collate_fn (Callable): [description]
(self, executor: BaseGraphExecutor,
dataloader: Iterable, collate_fn: Callable)
| 95 | self._verbose = verbose |
| 96 | |
def check(self, executor: BaseGraphExecutor,
          dataloader: Iterable, collate_fn: Callable):
    """Measure the current quantization error and keep it if it improved.

    Every batch from ``dataloader`` is run through ``executor`` and each
    requested output is pushed into its checkpoint as a non-reference
    tensor. Each checkpoint is then popped, the two tensor collections it
    yields are concatenated and scored with ``self._loss_fn``, and the
    checkpoint is cleared. The mean of these scores is finally compared
    with the mean of the ``best_loss`` values already stored.

    Args:
        executor (BaseGraphExecutor): Executor whose ``forward`` is called
            with ``inputs`` and ``output_names``.
        dataloader (Iterable): Source of input batches.
        collate_fn (Callable): Applied to every batch before it is handed
            to the executor; skipped when it is ``None``.

    Returns:
        bool: ``True`` when the stored mean loss is greater than the new
        mean loss multiplied by ``CHECKPOINT_TOLERANCE``; in that case the
        ``best_loss`` of every interested output is overwritten with its
        freshly measured loss. ``False`` otherwise, leaving the stored
        losses untouched.
    """

    # Phase 1: run the data through the executor and record its outputs.
    for batch in dataloader:
        if collate_fn is not None:
            batch = collate_fn(batch)
        results = executor.forward(
            inputs=batch, output_names=self._interested_outputs)
        for out_name, out_tensor in zip(self._interested_outputs, results):
            self._checkpoints[out_name].push(
                tensor=out_tensor, is_reference=False)

    # Phase 2: one loss value per interested output, in output order.
    per_output_loss = []
    for out_name in self._interested_outputs:
        checkpoint = self._checkpoints[out_name]
        assert isinstance(checkpoint, FinetuneCheckPoint)
        predicted, expected = checkpoint.pop()
        predicted = torch.cat(list(predicted))
        expected = torch.cat(list(expected))
        score = self._loss_fn(y_pred=predicted, y_real=expected)
        per_output_loss.append(score.item())
        checkpoint.clear()

    # Phase 3: average the new losses and the previously stored best ones.
    # Both sums are divided by the number of freshly measured losses.
    total_now = sum(per_output_loss)
    total_old = sum(c.best_loss for c in self._checkpoints.values())
    count = len(per_output_loss)
    loss_now = total_now / count
    loss_old = total_old / count
    if self._verbose:
        print(f'NOISE-SIGNAL RATIO: {loss_old * 100 :.4f}% -> {loss_now * 100:.4f}%.')

    # Phase 4: without a sufficient loss drop, leave the stored losses alone.
    if not loss_old > (loss_now * CHECKPOINT_TOLERANCE):
        return False

    for out_name, fresh_loss in zip(self._interested_outputs, per_output_loss):
        checkpoint = self._checkpoints[out_name]
        assert isinstance(checkpoint, FinetuneCheckPoint)
        checkpoint.best_loss = fresh_loss
    return True
| 142 | |
| 143 | def optimize( |
| 144 | self, graph: BaseGraph, |