Fuse LayerNormalization subgraphs using pattern matching.
(self, exclusive_search: bool = False)
| 691 | op.type = "Gemm" |
| 692 | |
| 693 | def fuse_layernorm(self, exclusive_search: bool = False): |
| 694 | """Fuse Layernormalization with pattern matching.""" |
| 695 | |
| 696 | def _fuse(rm1: Operation, rm2: Operation, |
| 697 | eps: Operation, scale: torch.Tensor, |
| 698 | bias: torch.Tensor, layernorm_input_var: Variable, |
| 699 | layernorm_output_var: Variable) -> Operation: |
| 700 | |
| 701 | if rm2.type == rm1.type == 'ReduceMean': |
| 702 | if 'axes' not in rm1.attributes: return None |
| 703 | if 'axes' not in rm2.attributes: return None |
| 704 | if rm1.attributes['axes'] != rm2.attributes['axes']: return None |
| 705 | layernorm_axis = rm1.attributes['axes'] |
| 706 | if isinstance(layernorm_axis, list): |
| 707 | if len(layernorm_axis) != 1: return None |
| 708 | layernorm_axis = layernorm_axis[0] |
| 709 | if not isinstance(layernorm_axis, int): return None |
| 710 | else: layernorm_axis = -1 |
| 711 | |
| 712 | if not eps.inputs[-1].is_parameter: return None |
| 713 | value = eps.inputs[-1].value |
| 714 | value = convert_any_to_torch_tensor(value).cpu() |
| 715 | if value.numel() != 1: return None |
| 716 | layernorm_eps = value.item() |
| 717 | |
| 718 | layernorm_output_var.source_op.outputs.clear() |
| 719 | layernorm = self.graph.create_operation( |
| 720 | op_type = 'LayerNormalization', |
| 721 | attributes = {'axis': layernorm_axis, 'epsilon': layernorm_eps, 'stash_type': 0}, |
| 722 | inputs = [layernorm_input_var, self.graph.create_variable(value=scale, is_parameter=True)], |
| 723 | outputs = [layernorm_output_var]) |
| 724 | |
| 725 | if bias is not None: |
| 726 | self.graph.create_link_with_op( |
| 727 | variable=self.graph.create_variable(value=bias, is_parameter=True), |
| 728 | A=None, B=layernorm) |
| 729 | return layernorm |
| 730 | |
| 731 | search_engine = SearchableGraph(graph=self.graph) |
| 732 | fused = False |
| 733 | |
| 734 | # pattern 1: |
| 735 | # --- --- --- --- --- --- --- -- |
| 736 | # | | |
| 737 | # ***(0) --- ReduceMean(1) --- Sub(2) --- Pow(3) --- ReduceMean(4) --- Add(5) --- Sqrt(6) --- Div(7) --- Mul(8) --- (Add)(9) |
| 738 | # | | |
| 739 | # --- --- --- --- |
| 740 | matches = search_engine.pattern_matching( |
| 741 | patterns=[lambda x: True, lambda x: x.type in {'ReduceMean', 'GlobalAveragePool'}, 'Sub', 'Pow', |
| 742 | lambda x: x.type in {'ReduceMean', 'GlobalAveragePool'}, 'Add', 'Sqrt', 'Div', 'Mul'], |
| 743 | edges=[[0, 1], [0, 2], [1, 2], [2, 3], [2, 7], [3, 4], [4, 5], [5, 6], [6, 7], [7, 8]], exclusive=exclusive_search) |
| 744 | |
| 745 | for _, rm1, sub, pow, rm2, add, sqrt, div, mul in matches: |
| 746 | layernorm_ops = [rm1, sub, pow, rm2, add, sqrt, div, mul] |
| 747 | |
| 748 | layernorm_scale = mul.inputs[-1].value |
| 749 | layernorm_output_var = div.outputs[0] |
| 750 | layernorm_input_var = sub.inputs[0] |
no test coverage detected