Fuse Gelu. Pattern: * - Div - Erf - Add - Mul - Mul, where the pattern input (*) also feeds the first Mul directly.
(self)
| 843 | ppq_warning('No valid Skip Layernorm pattern was found, check your graph again.') |
| 844 | |
def fuse_gelu(self) -> None:
    """Fuse every matched Div-Erf-Add-Mul-Mul chain into a single Gelu op.

    Pattern: * - Div - Erf - Add - Mul - Mul
             |                    |
             ----------------------

    The matched source op (``*``) feeds both the ``Div`` and the first
    ``Mul``; the second ``Mul`` produces the output of the pattern.  For
    each match the five ops and the four intermediate variables between
    them are removed, and one ``Gelu`` op is created that takes the source
    op's outputs as inputs and writes to the final ``Mul``'s outputs.

    NOTE(review): only the op types and edges are matched here; the
    constant operands of ``Div`` / ``Add`` / ``Mul`` are not inspected.

    If no match is found, a warning is emitted through ``ppq_warning``.

    Raises:
        AssertionError: if the source op of a match does not have exactly
            one output.  The check runs before the graph is modified, so
            a failing match leaves its ops untouched.
    """
    fused = False
    search_engine = SearchableGraph(graph=self.graph)

    matches = search_engine.pattern_matching(
        patterns=[lambda x: True, 'Div', 'Erf', 'Add', 'Mul', 'Mul'],
        edges=[[0, 1], [1, 2], [2, 3], [3, 4], [0, 4], [4, 5]],
        exclusive=True)

    for source_op, div, erf, add, mul1, mul2 in matches:
        # Validate first: once ops are removed the rewrite cannot be undone,
        # so a bad match must be rejected while the graph is still intact.
        input_vars = source_op.outputs.copy()
        assert len(input_vars) == 1, 'Fusion failed, Pattern unrecognized.'

        inner_ops = (div, erf, add, mul1)

        # Collect the intermediate variables before removing their producers
        # (same order as the original implementation).
        removing_var = []
        for op in inner_ops:
            removing_var.extend(op.outputs)

        for op in inner_ops:
            self.graph.remove_operation(op)
        for var in removing_var:
            self.graph.remove_variable(var)

        # Snapshot the final outputs before mul2 itself is removed.
        output_vars = mul2.outputs.copy()
        self.graph.remove_operation(mul2)

        self.graph.create_operation(
            op_type='Gelu', inputs=input_vars, outputs=output_vars)
        fused = True

    # final check, if no valid pattern was found, we give a warning.
    if not fused:
        ppq_warning('No valid Gelu pattern was found, check your graph again.')
| 884 | |
| 885 | def fuse_bias_add(self): |
| 886 | """ |
no test coverage detected