(self, graph: BaseGraph, dataloader: Iterable,
executor: BaseGraphExecutor, collate_fn: Callable,
**kwargs)
| 400 | self.name = 'PPQ Channelwise Split Pass' |
| 401 | |
def optimize(self, graph: BaseGraph, dataloader: Iterable,
             executor: BaseGraphExecutor, collate_fn: Callable,
             **kwargs) -> None:
    """Run the channel split procedure on the equalization pairs of ``graph``.

    Steps, in order:

    1. Pick the operations to work on: every operation whose type is in
       ``EQUALIZATION_OPERATION_TYPE`` when ``self.interested_layers`` is
       ``None``; otherwise the operations named in ``self.interested_layers``
       (names that are not present in ``graph.operations`` are skipped).
    2. Ask ``self.find_equalization_pair`` for the pairs built from them.
    3. Repeat ``self.iterations`` times: optionally refresh activations and
       write them into ``graph.variables``, then call ``channel_split`` on
       every pair that does not involve a grouped convolution.
    4. Call ``store_parameter_value()`` on every ``QuantableOperation``.

    Args:
        graph: Graph whose operations and variables are read and modified.
        dataloader: Forwarded to ``self.collect_activations``; only used
            when ``self.including_act`` is truthy.
        executor: Forwarded to ``self.collect_activations``.
        collate_fn: Forwarded to ``self.collect_activations``.
        **kwargs: Accepted but not read by this method.
    """
    # Step 1: resolve the set of operations this pass is allowed to touch.
    if self.interested_layers is None:
        targets = [
            candidate for candidate in graph.operations.values()
            if candidate.type in EQUALIZATION_OPERATION_TYPE
        ]
    else:
        targets = [
            graph.operations[layer_name]
            for layer_name in self.interested_layers
            if layer_name in graph.operations
        ]

    # Step 2: group them into equalization pairs.
    pairs = self.find_equalization_pair(
        graph=graph, interested_operations=targets)
    print(f'{len(pairs)} equalization pair(s) was found, ready to run optimization.')

    # Step 3: iterate the split.
    for _ in tqdm(range(self.iterations), desc='Layerwise Channel Split', total=self.iterations):
        if self.including_act:
            collected = self.collect_activations(
                graph=graph, executor=executor, dataloader=dataloader,
                collate_fn=collate_fn, operations=targets)

            # Write the collected activation values back into the network.
            for var_name, activation in collected.items():
                graph.variables[var_name].value = activation

        for pair in pairs:
            # Group convolutions can not be split: skip any pair in which a
            # Conv / ConvTranspose layer declares a 'group' other than 1
            # (a missing 'group' attribute is treated as 1).
            conv_groups = [
                layer.attributes.get('group', 1)
                for layer in pair.downstream_layers + pair.upstream_layers
                if layer.type in {'Conv', 'ConvTranspose'}
            ]
            if any(group != 1 for group in conv_groups):
                continue

            pair.channel_split(
                value_threshold=self.value_threshold,
                including_bias=self.including_bias,
                including_act=self.including_act,
                bias_multiplier=self.bias_multiplier,
                act_multiplier=self.act_multiplier)

    # Step 4: channel split directly changes the fp32 value of the weights,
    # so store it for the following procedure.
    # NOTE(review): the listing this was recovered from had lost its
    # indentation; this loop is assumed to run once, after all iterations
    # -- confirm against the original file.
    for operation in graph.operations.values():
        if isinstance(operation, QuantableOperation):
            operation.store_parameter_value()
nothing calls this directly
no test coverage detected