(self, quant_inputs: List[torch.Tensor], fp32_outputs: List[torch.Tensor],
executor: TorchExecutor, block: TrainableBlock, dataloader: Iterable, collate_fn: Callable)
| 49 | |
@torch.no_grad()
def calib_block(self, quant_inputs: List[torch.Tensor], fp32_outputs: List[torch.Tensor],
                executor: TorchExecutor, block: TrainableBlock, dataloader: Iterable, collate_fn: Callable):
    """Tune the quantization configs of one block using bandit delegators.

    Each quantization config in state ``QuantizationStates.ACTIVATED`` that
    belongs to a ``QuantableOperation`` in ``block.rps`` gets a
    ``BanditDelegator`` (built from ``self.arms``). That delegator is
    registered on ``executor`` for the duration of training.

    The training loop runs ``self.target_step`` iterations. Each iteration:
    - pops a (quantized input, fp32 output) pair from a ``RandomMemDataset``;
    - runs the block through ``executor.partial_graph_forward``;
    - rewards every delegator with ``1 - torch_snr_error(...)``.

    After training, the delegators are unregistered and finalized. If
    ``self.check(...)`` then returns a falsy value, every delegator is
    withdrawn.

    Args:
        quant_inputs: cached values fed to ``block.sp.inputs[0]``, paired
            element-wise (via ``zip``) with ``fp32_outputs``.
        fp32_outputs: reference values compared against the block's output
            ``block.ep.outputs[0]``.
        executor: executor used for the partial forward passes; the
            delegators are registered on it while training runs.
        block: the block whose quantization configs are tuned.
        dataloader: forwarded to ``self.check``.
        collate_fn: forwarded to ``self.check``.
    """
    # One delegator per activated quantization config inside the block.
    delegators = []
    for operation in block.rps:
        if not isinstance(operation, QuantableOperation):
            continue
        for cfg, _ in operation.config_with_variable:
            if cfg.state == QuantizationStates.ACTIVATED:
                delegators.append(BanditDelegator(arms=self.arms, config=cfg))

    dataset = RandomMemDataset(data=[[qt, fp] for qt, fp in zip(quant_inputs, fp32_outputs)])
    device = executor._executing_context.executing_device
    loss_ema = EMARecorder(beta=0.98)
    output_var = block.ep.outputs[0]
    input_var = block.sp.inputs[0]

    for delegator in delegators:
        executor.register_quantize_delegate(config=delegator.config, delegator=delegator)

    # Always unregister the delegators, even if training raises; otherwise the
    # executor would keep routing quantization through stale delegates.
    try:
        with tqdm(total=self.target_step) as t:
            t.set_description(desc=f'Block [{self._bidx + 1}/{self._num_of_blocks}]')
            cur_iter = 0
            while cur_iter < self.target_step:
                qt_input, fp_output = dataset.pop()
                qt_input, fp_output = qt_input.to(device), fp_output.to(device)

                qt_output = executor.partial_graph_forward(
                    operations=block.rps, feed_dict={input_var.name: qt_input},
                    output_names=[output_var.name])[0]

                loss = torch_snr_error(y_pred=qt_output, y_real=fp_output).item()
                # Reward is 1 - error, so a lower error yields a higher reward.
                for delegator in delegators:
                    delegator.mark(1 - loss)
                loss_ema.push(loss)

                cur_iter += 1
                # Advance every step so the bar reaches `total` even when
                # target_step is not a multiple of 50.
                t.update(1)
                if cur_iter % 50 == 0:
                    t.set_postfix(loss=loss_ema.pop())
    finally:
        for delegator in delegators:
            executor.remove_quantize_delegate(config=delegator.config)

    for delegator in delegators:
        delegator.finalize()

    # Roll back the tuned configs if the post-training check does not pass.
    if not self.check(executor=executor, dataloader=dataloader, collate_fn=collate_fn):
        for delegator in delegators:
            delegator.withdraw()
| 98 | |
| 99 | def collect_training_data( |
| 100 | self, output_name: str, |
no test coverage detected