MCPcopy
hub / github.com/OpenPPL/ppq / calib_block

Method calib_block

ppq/quantization/optim/exprimental.py:51–97  ·  view source on GitHub ↗
(self, quant_inputs: List[torch.Tensor], fp32_outputs: List[torch.Tensor],
        executor: TorchExecutor, block: TrainableBlock, dataloader: Iterable, collate_fn: Callable)

Source from the content-addressed store, hash-verified

49
@torch.no_grad()
def calib_block(self, quant_inputs: List[torch.Tensor], fp32_outputs: List[torch.Tensor],
                executor: TorchExecutor, block: TrainableBlock, dataloader: Iterable, collate_fn: Callable):
    """Tune the quantization configs of one block using bandit delegators.

    Creates a ``BanditDelegator`` for every ACTIVATED quantization config of
    every ``QuantableOperation`` in ``block.rps``. Each delegator is registered
    on ``executor``. Then ``self.target_step`` partial forward passes of the
    block are run on randomly drawn (quantized input, fp32 output) pairs. After
    each pass, every delegator is marked with ``1 - snr_error`` between the
    quantized output and the fp32 reference.

    When the loop ends, the delegators are unregistered. This also happens if
    the loop raises. On a normal finish, the delegators are finalized. If
    ``self.check`` then returns a falsy value, every delegator is withdrawn.

    Args:
        quant_inputs: tensors fed to ``block.sp.inputs[0]``, paired
            index-by-index with ``fp32_outputs`` via ``zip``. Surplus entries
            of the longer list are ignored.
        fp32_outputs: reference tensors compared with the output of
            ``block.ep.outputs[0]``.
        executor: executor used for delegate registration and
            ``partial_graph_forward``.
        block: block to calibrate. Its ``rps`` operations are executed and tuned.
        dataloader: forwarded to ``self.check``.
        collate_fn: forwarded to ``self.check``.
    """
    # One bandit delegator per ACTIVATED quantization config in the block.
    delegators = []
    for operation in block.rps:
        if not isinstance(operation, QuantableOperation):
            continue
        for cfg, _ in operation.config_with_variable:
            if cfg.state == QuantizationStates.ACTIVATED:
                delegators.append(BanditDelegator(arms=self.arms, config=cfg))

    dataset = RandomMemDataset(data=[[qt, fp] for qt, fp in zip(quant_inputs, fp32_outputs)])
    device = executor._executing_context.executing_device
    loss_ema = EMARecorder(beta=0.98)
    output_var = block.ep.outputs[0]
    input_var = block.sp.inputs[0]
    report_interval = 50

    # Track what was actually registered, so the finally-clause unregisters
    # exactly those delegators. Without it, an exception in the loop would
    # leave delegators attached to the executor after this method exits.
    registered = []
    try:
        for delegator in delegators:
            executor.register_quantize_delegate(config=delegator.config, delegator=delegator)
            registered.append(delegator)

        reported = 0
        with tqdm(total=self.target_step) as t:
            for cur_iter in range(1, self.target_step + 1):
                qt_input, fp_output = dataset.pop()
                qt_input, fp_output = qt_input.to(device), fp_output.to(device)

                qt_output = executor.partial_graph_forward(
                    operations=block.rps, feed_dict={input_var.name: qt_input},
                    output_names=[output_var.name])[0]

                loss = torch_snr_error(y_pred=qt_output, y_real=fp_output).item()
                # Reward is 1 - SNR error, so a smaller error yields a larger reward.
                for delegator in delegators:
                    delegator.mark(1 - loss)
                loss_ema.push(loss)

                # Refresh every `report_interval` steps and also on the final step,
                # so the bar reaches its total even when target_step is not a multiple
                # of the interval.
                if cur_iter % report_interval == 0 or cur_iter == self.target_step:
                    t.set_description(desc=f'Block [{self._bidx + 1}/{self._num_of_blocks}]')
                    t.set_postfix(loss=loss_ema.pop())
                    t.update(cur_iter - reported)
                    reported = cur_iter
    finally:
        for delegator in registered:
            executor.remove_quantize_delegate(config=delegator.config)

    for delegator in delegators:
        delegator.finalize()

    # Roll back every delegator when the calibrated block does not pass self.check.
    if not self.check(executor=executor, dataloader=dataloader, collate_fn=collate_fn):
        for delegator in delegators:
            delegator.withdraw()
98
99 def collect_training_data(
100 self, output_name: str,

Callers 1

optimize (method) · 0.95

Calls 15

pop (method) · 0.95
push (method) · 0.95
pop (method) · 0.95
BanditDelegator (class) · 0.90
EMARecorder (class) · 0.90
torch_snr_error (function) · 0.90
RandomMemDataset (class) · 0.85
to (method) · 0.80
partial_graph_forward (method) · 0.80
mark (method) · 0.80
update (method) · 0.80

Tested by

no test coverage detected