MCPcopy Index your code
hub / github.com/OpenPPL/ppq / check

Method check

ppq/quantization/optim/training.py:97–141  ·  view source on GitHub ↗

Check the quantization error on a given dataloader using the current checkpoints, and return whether the quantization error is lower than before. Args: executor (BaseGraphExecutor): [description] dataloader (Iterable): [description] collate_fn (Callable): [description]

(self, executor: BaseGraphExecutor,
        dataloader: Iterable, collate_fn: Callable)

Source from the content-addressed store, hash-verified

95 self._verbose = verbose
96
def check(self, executor: BaseGraphExecutor,
          dataloader: Iterable, collate_fn: Callable):
    """Measure the current quantization error and compare it with the best
    loss recorded so far.

    Every batch of ``dataloader`` is run through ``executor`` and each
    requested output is pushed into its checkpoint with
    ``is_reference=False``. Each checkpoint is then popped, the two tensor
    collections it returns are concatenated and scored with
    ``self._loss_fn``, and the checkpoint is cleared. The mean of the new
    losses is compared with the mean of the stored ``best_loss`` values of
    the same checkpoints; on improvement every interested checkpoint's
    ``best_loss`` is overwritten with its new loss.

    Args:
        executor (BaseGraphExecutor): Its ``forward`` is called with
            ``inputs=<batch>`` and
            ``output_names=self._interested_outputs``.
        dataloader (Iterable): Source of input batches.
        collate_fn (Callable): Applied to each batch before the forward
            pass. ``None`` is accepted and skips this step.

    Returns:
        bool: ``True`` when ``loss_old > loss_now * CHECKPOINT_TOLERANCE``
        (the stored best losses are updated), ``False`` otherwise,
        including when there are no interested outputs to compare.
    """
    names = self._interested_outputs

    # step - 1, collecting data
    for data in dataloader:
        if collate_fn is not None:
            data = collate_fn(data)
        outputs = executor.forward(inputs=data, output_names=names)
        for name, output in zip(names, outputs):
            self._checkpoints[name].push(tensor=output, is_reference=False)

    # step - 2, calculating loss
    losses = []
    for name in names:
        ckpt = self._checkpoints[name]
        assert isinstance(ckpt, FinetuneCheckPoint)
        # presumably (quantized outputs, floating-point references) --
        # confirm against FinetuneCheckPoint.pop.
        qt_out, fp_out = ckpt.pop()
        # Copy into fresh lists before concatenating so the following
        # clear() cannot affect what is being scored.
        qt_out = torch.cat(list(qt_out))
        fp_out = torch.cat(list(fp_out))
        losses.append(self._loss_fn(y_pred=qt_out, y_real=fp_out).item())
        ckpt.clear()

    # Nothing was measured, so no improvement can be claimed; this also
    # avoids dividing by zero below.
    if not losses:
        return False

    # step - 3, comparing loss
    # Average the previous best loss over the same checkpoints that
    # produced `losses`, so both means share one denominator.
    loss_now = sum(losses) / len(losses)
    loss_old = sum(self._checkpoints[name].best_loss for name in names) / len(losses)
    if self._verbose: print(f'NOISE-SIGNAL RATIO: {loss_old * 100 :.4f}% -> {loss_now * 100:.4f}%.')

    # if there is a loss drop, update all losses.
    if loss_old > (loss_now * CHECKPOINT_TOLERANCE):
        for name, loss in zip(names, losses):
            ckpt = self._checkpoints[name]
            assert isinstance(ckpt, FinetuneCheckPoint)
            ckpt.best_loss = loss
        return True
    return False
142
143 def optimize(
144 self, graph: BaseGraph,

Callers 2

calib_blockMethod · 0.80

Calls 6

collate_fnFunction · 0.50
forwardMethod · 0.45
pushMethod · 0.45
popMethod · 0.45
appendMethod · 0.45
clearMethod · 0.45

Tested by

no test coverage detected