MCPcopy
hub / github.com/OpenPPL/ppq / initialize_checkpoints

Method initialize_checkpoints

ppq/quantization/optim/training.py:46–95  ·  view source on GitHub ↗

Establish a series of network checkpoints for your network. A checkpoint is a data structure that helps us compare quantized results against fp32 results. Args: graph (BaseGraph): [description]; executor (BaseGraphExecutor): [description]; dataloader (Iterable): [description]; collate_fn (Callable): [description].

(
        self, graph: BaseGraph, executor: BaseGraphExecutor,
        dataloader: Iterable, collate_fn: Callable)

Source from the content-addressed store, hash-verified

44
@empty_ppq_cache
def initialize_checkpoints(
    self, graph: BaseGraph, executor: BaseGraphExecutor,
    dataloader: Iterable, collate_fn: Callable):
    """Establish a series of network checkpoints with your network.

    A checkpoint is a data structure that helps us compare quant results
    and fp32 results. This method creates one ``FinetuneCheckPoint`` per
    interested output, runs the dequantized graph over ``dataloader`` to
    push reference outputs into those checkpoints, restores the
    quantization state, and finally calls ``self.check`` once with
    ``self._verbose`` temporarily set to False.

    Args:
        graph (BaseGraph): Graph whose quantable operations are validated
            and temporarily dequantized. When no interested outputs have
            been set, the names yielded by ``graph.outputs`` are used.
        executor (BaseGraphExecutor): Executor used for the forward passes.
        dataloader (Iterable): Source of input batches. It is iterated
            here and then handed to ``self.check`` as well.
            NOTE(review): a one-shot iterator would presumably be exhausted
            before ``check`` sees it -- confirm callers pass a re-iterable.
        collate_fn (Callable): Applied to every batch before it is fed to
            the executor; skipped when it is None.

    Raises:
        PermissionError: If any config of a quantable operation is in the
            ``BAKED`` or ``PASSIVE_BAKED`` state.
    """
    quantable_ops = [
        op for op in graph.operations.values()
        if isinstance(op, QuantableOperation)]

    # Refuse to run when any quantable operation carries a baked config.
    baked_states = {QuantizationStates.BAKED, QuantizationStates.PASSIVE_BAKED}
    for operation in quantable_ops:
        for cfg, var in operation.config_with_variable:
            if cfg.state in baked_states:
                raise PermissionError(
                    'Can not initialize checkpoints when weight value is baked. '
                    f'Variable {var.name} has a baked value.')

    # Default to the graph outputs when nothing specific was requested.
    if not self._interested_outputs:
        self._interested_outputs = list(graph.outputs)

    for name in self._interested_outputs:
        self._checkpoints[name] = FinetuneCheckPoint(variable=name)

    # Dequantize graph, collect references. Track which operations were
    # actually dequantized so that exactly those are restored, even when a
    # forward pass (or collate_fn) raises half-way through.
    dequantized_ops = []
    try:
        for op in quantable_ops:
            op.dequantize()
            dequantized_ops.append(op)

        for data in tqdm(dataloader, desc='Collecting Referecens'):
            if collate_fn is not None:
                data = collate_fn(data)
            outputs = executor.forward(
                inputs=data, output_names=self._interested_outputs)
            for name, output in zip(self._interested_outputs, outputs):
                ckpt = self._checkpoints[name]
                assert isinstance(ckpt, FinetuneCheckPoint)
                ckpt.push(tensor=output, is_reference=True)
    finally:
        # Restore quantization state, on success and on failure alike.
        for op in dequantized_ops:
            op.restore_quantize_state()

    # Update state: run one check with verbose output switched off, and
    # always put the caller's verbosity setting back afterwards.
    verbose, self._verbose = self._verbose, False
    try:
        self.check(executor=executor, dataloader=dataloader, collate_fn=collate_fn)
    finally:
        self._verbose = verbose
96
97 def check(self, executor: BaseGraphExecutor,
98 dataloader: Iterable, collate_fn: Callable):

Callers 1

optimizeMethod · 0.80

Calls 7

checkMethod · 0.95
FinetuneCheckPointClass · 0.85
dequantizeMethod · 0.80
collate_fnFunction · 0.50
forwardMethod · 0.45
pushMethod · 0.45

Tested by

no test coverage detected