(self,
pair: List[Operation],
executor: BaseGraphExecutor,
data_loader: Iterable,
collate_fn: Callable,
calib_steps: int
)
| 429 | |
@torch.no_grad()
def test_ssd_loss(self,
                  pair: List[Operation],
                  executor: BaseGraphExecutor,
                  data_loader: Iterable,
                  collate_fn: Callable,
                  calib_steps: int
                  ) -> float:
    """Estimate the mean MSE between full-precision and quantized outputs of ``pair``.

    Observers are built for the pair and calibrated. A second calibration
    pass runs only for ops that own at least one ``TorchHistObserver``.
    After that, ``calibration_passive_param`` is called. The pair is then
    evaluated on at most ``calib_steps`` batches, once with quantization
    removed (``dequantize_pair``) and once with it restored
    (``restore_quantize_state``). The per-batch MSE values are averaged.

    Args:
        pair: Operations to evaluate. ``pair[0].inputs[0]`` names the tensor
            that is pulled from the executor and fed into the pair.
        executor: Graph executor used to produce the pair's input tensor.
        data_loader: Calibration batches. ``len(data_loader)`` is used to
            compute the number of epochs, so it must be sized.
        collate_fn: Optional transform applied to each batch; skipped when
            ``None``.
        calib_steps: Batch budget used for calibration and for the loss
            estimate.

    Returns:
        Mean of the per-batch MSE values, as a Python float.
    """
    observers = self.build_observer_pair(pair)
    hooks = {op: observer.hook for op, observer in observers.items()}

    # First calibration pass over every observer.
    self.calibrate(pair, data_loader, executor, hooks, collate_fn, calib_steps)
    for observer in observers.values():
        observer.render_quantization_config()

    # Only ops holding a TorchHistObserver (exact type match) need the second
    # pass. Drop every other op from both observers and hooks.
    single_pass_ops = [
        op for op, observer in observers.items()
        if all(type(var_observer) is not TorchHistObserver
               for var_observer in observer._hook._observer_table.values())
    ]
    for op in single_pass_ops:
        observers.pop(op)
        hooks.pop(op)
    if hooks:
        self.calibrate(pair, data_loader, executor, hooks, collate_fn, calib_steps)
        for observer in observers.values():
            observer.render_quantization_config()
    self.calibration_passive_param(pair)

    # Loss estimation. Honour the same ``calib_steps`` batch budget as
    # calibration: without the break below, every batch of every epoch would
    # be evaluated, i.e. up to len(data_loader) * ceil(...) batches.
    input_name = pair[0].inputs[0].name
    loss = []
    for _ in range(ceil(calib_steps / len(data_loader))):
        for data in data_loader:
            if collate_fn is not None:
                data = collate_fn(data)
            # Run the graph only as far as the pair's input tensor.
            inputs = executor.forward(data, output_names=[input_name])
            # Full-precision reference output. Restore the quant state even if
            # run_pair raises, so the pair is never left dequantized.
            self.dequantize_pair(pair)
            try:
                fp_output = self.run_pair(pair, inputs)
            finally:
                self.restore_quantize_state(pair)
            quant_output = self.run_pair(pair, inputs)
            loss.append(self.calculate_mse(fp_output, quant_output))
            if len(loss) >= calib_steps:
                break
    return torch.stack(loss).mean().item()
| 473 | |
| 474 | # maintain original parameter for restoration in case of a larger loss after equalization |
| 475 | def collect_original_parameter(self, pair: List[Operation]) -> Dict[Variable, torch.Tensor]: |
no test coverage detected