(
self, graph: BaseGraph, dataloader: Iterable,
executor: BaseGraphExecutor, collate_fn: Callable, **kwargs
)
| 332 | |
@empty_ppq_cache
def optimize(
    self, graph: BaseGraph, dataloader: Iterable,
    executor: BaseGraphExecutor, collate_fn: Callable, **kwargs
) -> None:
    """Run layerwise equalization on ``graph``, modifying its weights in place.

    Steps:
      1. Select the operations to equalize: every operation whose type is in
         ``EQUALIZATION_OPERATION_TYPE`` when ``self.interested_layers`` is
         None, otherwise the named operations that exist in the graph.
      2. Group them into equalization pairs via ``self.find_equalization_pair``.
      3. If ``self.including_act`` is set, collect activations for the selected
         operations and write them back into ``graph.variables``.
      4. Call ``equalize`` on every pair, ``self.iterations`` times.
      5. Store the updated fp32 parameters of every ``QuantableOperation``.

    Args:
        graph: graph to optimize; its parameter values are changed in place.
        dataloader: calibration data forwarded to ``collect_activations``
            (only used when ``self.including_act`` is set).
        executor: executor forwarded to ``collect_activations``.
        collate_fn: collate function forwarded to ``collect_activations``.
        **kwargs: accepted for interface compatibility; not used here.
    """
    if self.interested_layers is None:
        # No explicit selection: every operation whose type supports equalization.
        interested_operations = [
            operation for operation in graph.operations.values()
            if operation.type in EQUALIZATION_OPERATION_TYPE]
    else:
        # Explicit selection by name. Unknown names are still skipped, but they
        # are reported so a typo does not silently disable equalization.
        interested_operations = []
        missing_layers = []
        for name in self.interested_layers:
            if name in graph.operations:
                interested_operations.append(graph.operations[name])
            else:
                missing_layers.append(name)
        if missing_layers:
            print(f'Warning: {len(missing_layers)} interested layer(s) not found in graph '
                  f'and will be ignored: {missing_layers}')

    pairs = self.find_equalization_pair(
        graph=graph, interested_operations=interested_operations)

    if self.including_act:
        activations = self.collect_activations(
            graph=graph, executor=executor, dataloader=dataloader, collate_fn=collate_fn,
            operations=interested_operations)

        for name, act in activations.items():
            # Write the collected activation values back into the graph's variables.
            graph.variables[name].value = act

    print(f'{len(pairs)} equalization pair(s) was found, ready to run optimization.')
    for _ in tqdm(range(self.iterations), desc='Layerwise Equalization', total=self.iterations):
        for equalization_pair in pairs:
            equalization_pair.equalize(
                value_threshold=self.value_threshold,
                including_bias=self.including_bias,
                including_act=self.including_act,
                bias_multiplier=self.bias_multiplier,
                act_multiplier=self.act_multiplier)

    # Equalization directly changes the fp32 value of weights;
    # store it for the following procedures.
    for op in graph.operations.values():
        if isinstance(op, QuantableOperation):
            op.store_parameter_value()
| 377 | |
| 378 | |
| 379 | class ChannelwiseSplitPass(LayerwiseEqualizationPass): |
nothing calls this directly
no test coverage detected