(self)
| 515 | |
| 516 | |
| 517 | def fuse_bn(self): |
| 518 | search_engine = SearchableGraph(graph=self.graph) |
| 519 | paths = search_engine.path_matching( |
| 520 | sp_expr=lambda x: x.type in {'Conv', 'Gemm', 'ConvTranspose'}, |
| 521 | rp_expr=lambda x, y: False, |
| 522 | ep_expr=lambda x: x.type == 'BatchNormalization', |
| 523 | direction='down') |
| 524 | |
| 525 | for path in paths: |
| 526 | path = path.tolist() |
| 527 | assert len(path) == 2, ('Oops seems we got something unexpected.') |
| 528 | |
| 529 | computing_op, bn_op = path |
| 530 | assert isinstance(computing_op, Operation) and isinstance(bn_op, Operation) |
| 531 | |
| 532 | if (len(self.graph.get_downstream_operations(computing_op)) != 1 or |
| 533 | len(self.graph.get_upstream_operations(bn_op)) != 1): |
| 534 | ppq_warning(f'PPQ can not merge operation {computing_op.name} and {bn_op.name}, ' |
| 535 | 'this is not suppose to happen with your network, ' |
| 536 | 'network with batchnorm inside might not be able to quantize and deploy.') |
| 537 | continue |
| 538 | |
| 539 | assert len(bn_op.parameters) == 4, 'BatchNorm should have 4 parameters, namely alpha, beta, mean, var' |
| 540 | alpha = bn_op.parameters[0].value |
| 541 | beta = bn_op.parameters[1].value |
| 542 | mean = bn_op.parameters[2].value |
| 543 | var = bn_op.parameters[3].value |
| 544 | epsilon = bn_op.attributes.get('epsilon', 1e-5) |
| 545 | |
| 546 | if computing_op.num_of_parameter == 1: |
| 547 | w = computing_op.parameters[0].value # no bias. |
| 548 | assert isinstance(w, torch.Tensor), 'values of parameters are assumed as torch Tensor' |
| 549 | if computing_op.type == 'ConvTranspose': |
| 550 | b = torch.zeros(w.shape[1] * computing_op.attributes.get('group', 1)) |
| 551 | elif computing_op.type == 'Gemm' and computing_op.attributes.get('transB', 0) == 0: |
| 552 | b = torch.zeros(w.shape[1]) |
| 553 | else: |
| 554 | b = torch.zeros(w.shape[0]) |
| 555 | else: |
| 556 | w, b = [var.value for var in computing_op.parameters[: 2]] # has bias. |
| 557 | |
| 558 | if computing_op.type == 'Conv': |
| 559 | |
| 560 | # calculate new weight and bias |
| 561 | scale = alpha / torch.sqrt(var + epsilon) |
| 562 | w = w * scale.reshape([-1] + [1] * (w.ndim - 1)) |
| 563 | b = alpha * (b - mean) / torch.sqrt(var + epsilon) + beta |
| 564 | |
| 565 | elif computing_op.type == 'Gemm': |
| 566 | |
| 567 | # calculate new weight and bias |
| 568 | scale = alpha / torch.sqrt(var + epsilon) |
| 569 | if computing_op.attributes.get('transB', 0): |
| 570 | w = w * scale.reshape([-1, 1]) |
| 571 | else: |
| 572 | w = w * scale.reshape([1, -1]) |
| 573 | b = alpha * (b - mean) / torch.sqrt(var + epsilon) + beta |
| 574 |
no test coverage detected