MCPcopy
hub / github.com/OpenPPL/ppq / test_ssd_loss

Method test_ssd_loss

ppq/quantization/optim/ssd.py:431–472  ·  view source on GitHub ↗
(self,
                    pair: List[Operation],
                    executor: BaseGraphExecutor,
                    data_loader: Iterable,
                    collate_fn: Callable,
                    calib_steps: int
    )

Source from the content-addressed store, hash-verified

429
@torch.no_grad()
def test_ssd_loss(self,
                  pair: List[Operation],
                  executor: BaseGraphExecutor,
                  data_loader: Iterable,
                  collate_fn: Callable,
                  calib_steps: int
    ) -> float:
    """Calibrate ``pair`` and return its mean quantization error.

    Observers are built and calibrated for the operations in ``pair``, their
    quantization configs are rendered, and the pair is then executed twice
    per batch (dequantized and quantized). The value returned is the mean of
    ``self.calculate_mse`` over all processed batches.

    Args:
        pair: Operations to evaluate; the first input variable of
            ``pair[0]`` is used as the entry point of the pair.
        executor: Executor used for calibration and for producing the
            input fed to the pair.
        data_loader: Calibration batches. ``len(data_loader)`` is taken,
            so it must be sized and non-empty.
        collate_fn: Optional batch pre-processing function; skipped when
            it is ``None``.
        calib_steps: Requested number of calibration steps.

    Returns:
        The mean of the per-batch MSE values as a Python float.
    """
    observers = self.build_observer_pair(pair)
    hooks = {op: observers[op].hook for op in observers}

    # First calibration pass over every observed operation.
    self.calibrate(pair, data_loader, executor, hooks, collate_fn, calib_steps)
    for observer in observers.values():
        observer.render_quantization_config()

    # Keep only the operations that own at least one TorchHistObserver; the
    # exact-type comparison (not isinstance) is preserved on purpose.
    # NOTE(review): presumably histogram observers need a second pass - confirm.
    pop_list = [
        op for op, observer in observers.items()
        if all(type(var_observer) not in {TorchHistObserver}
               for var_observer in observer._hook._observer_table.values())
    ]
    for op in pop_list:
        observers.pop(op)
        hooks.pop(op)

    # Second calibration pass, only for the operations that remain.
    if hooks:
        self.calibrate(pair, data_loader, executor, hooks, collate_fn, calib_steps)
        for observer in observers.values():
            observer.render_quantization_config()
    self.calibration_passive_param(pair)

    # calculate loss
    # NOTE(review): every pass below walks the whole data_loader, so more
    # than calib_steps batches may be processed when calib_steps is not a
    # multiple of len(data_loader) - confirm this is intended.
    loss = []
    for _ in range(ceil(calib_steps / len(data_loader))):
        for data in data_loader:
            if collate_fn is not None:
                data = collate_fn(data)
            # get the input of first op
            inputs = executor.forward(data, output_names=[pair[0].inputs[0].name])
            # dequant to get fp output
            self.dequantize_pair(pair)
            try:
                fp_output = self.run_pair(pair, inputs)
            finally:
                # restore quant state even if the fp run fails, so the
                # pair is never left dequantized
                self.restore_quantize_state(pair)
            quant_output = self.run_pair(pair, inputs)
            # mse calculation
            loss.append(self.calculate_mse(fp_output, quant_output))
    return torch.stack(loss).mean().item()
473
474 # maintain original parameter for restoration in case of a larger loss after equalization
475 def collect_original_parameter(self, pair: List[Operation]) -> Dict[Variable, torch.Tensor]:

Callers 1

optimizeMethod · 0.95

Calls 12

build_observer_pairMethod · 0.95
calibrateMethod · 0.95
dequantize_pairMethod · 0.95
run_pairMethod · 0.95
calculate_mseMethod · 0.95
collate_fnFunction · 0.50
appendMethod · 0.45
popMethod · 0.45
forwardMethod · 0.45

Tested by

no test coverage detected