MCPcopy
hub / github.com/OpenPPL/ppq / fuse_gelu

Method fuse_gelu

ppq/IR/morph.py:845–883  ·  view source on GitHub ↗

Fuse Gelu. Pattern: * - Div - Erf - Add - Mul - Mul, with a second edge running from * directly into the first Mul.

(self)

Source from the content-addressed store, hash-verified

843 ppq_warning('No valid Skip Layernorm pattern was found, check your graph again.')
844
def fuse_gelu(self):
    """Fuse the decomposed Gelu subgraph into a single 'Gelu' operation.

    Pattern: * - Div - Erf - Add - Mul - Mul
             |                      |
             ------------------------

    The upstream operation (``*``) feeds both the Div and the first Mul.
    For every match, the Div, Erf, Add and first Mul are removed together
    with the variables they produce, the second Mul is removed, and a
    'Gelu' operation is created that consumes the upstream operation's
    outputs and produces the outputs of the second Mul.

    A ppq_warning is emitted when no pattern was matched at all.

    Raises:
        AssertionError: if the upstream operation of a match does not
            expose exactly one output variable. The check is made before
            the graph is modified for that match, so a rejected match is
            left untouched (matches fused earlier in the same call stay
            fused).
    """
    fused = False
    search_engine = SearchableGraph(graph=self.graph)

    # NOTE(review): exact semantics of `exclusive=True` are defined by
    # SearchableGraph.pattern_matching, which is not visible here.
    matches = search_engine.pattern_matching(
        patterns=[lambda x: True, 'Div', 'Erf', 'Add', 'Mul', 'Mul'],
        edges=[[0, 1], [1, 2], [2, 3], [3, 4], [0, 4], [4, 5]], exclusive=True)

    for upstream, div, erf, add, mul1, mul2 in matches:
        # Validate first: the fused op takes the upstream outputs as its
        # inputs, and only the single-output case is supported. Checking
        # before touching the graph avoids leaving a half-removed pattern
        # behind when the match is rejected.
        input_vars = upstream.outputs.copy()
        assert len(input_vars) == 1, 'Fusion failed, Pattern unrecognized.'

        inner_ops = (div, erf, add, mul1)

        # Snapshot the intermediate variables before removing their
        # producers, then drop the operations followed by the variables.
        removing_vars = [var for op in inner_ops for var in op.outputs]
        for op in inner_ops:
            self.graph.remove_operation(op)
        for var in removing_vars:
            self.graph.remove_variable(var)

        # The outputs of the last Mul become the outputs of the fused op.
        output_vars = mul2.outputs.copy()
        self.graph.remove_operation(mul2)
        self.graph.create_operation(op_type='Gelu', inputs=input_vars, outputs=output_vars)
        fused = True

    # final check, if no valid pattern was found, we give a warning.
    if not fused:
        ppq_warning('No valid Gelu pattern was found, check your graph again.')
884
885 def fuse_bias_add(self):
886 """

Callers 1

Calls 7

pattern_matchingMethod · 0.95
SearchableGraphClass · 0.90
ppq_warningFunction · 0.90
remove_operationMethod · 0.80
remove_variableMethod · 0.80
create_operationMethod · 0.80
copyMethod · 0.45

Tested by

no test coverage detected