MCPcopy
hub / github.com/OpenPPL/ppq / fuse_layernorm

Method fuse_layernorm

ppq/IR/morph.py:693–819  ·  view source on GitHub ↗

Fuse Layernormalization with pattern matching.

(self, exclusive_search: bool = False)

Source from the content-addressed store, hash-verified

691 op.type = "Gemm"
692
693 def fuse_layernorm(self, exclusive_search: bool = False):
694 """Fuse Layernormalization with pattern matching."""
695
def _fuse(rm1: Operation, rm2: Operation,
          eps: Operation, scale: torch.Tensor,
          bias: torch.Tensor, layernorm_input_var: Variable,
          layernorm_output_var: Variable) -> Operation:
    """Replace one matched subgraph with a single 'LayerNormalization' op.

    Args:
        rm1: First reduction op of the matched pattern.
        rm2: Second reduction op of the matched pattern.
        eps: Op whose last input is expected to hold the epsilon constant.
        scale: Value stored in the parameter variable used as the second
            input of the new op.
        bias: Optional value; when not None it is wrapped in a parameter
            variable and linked to the new op.
        layernorm_input_var: Variable used as the first input of the new op.
        layernorm_output_var: Variable that becomes the output of the new op.

    Returns:
        The newly created operation, or None when the match fails validation
        (in which case the graph has not been modified by this function).
    """
    reduces_are_mean = rm1.type == 'ReduceMean' and rm2.type == 'ReduceMean'
    if not reduces_are_mean:
        # Any other type combination falls back to the last axis.
        # NOTE(review): the caller's pattern also admits GlobalAveragePool;
        # confirm that -1 is the right axis for those matches.
        axis = -1
    else:
        # Both reductions must declare the same 'axes' attribute.
        if 'axes' not in rm1.attributes or 'axes' not in rm2.attributes:
            return None
        axis = rm1.attributes['axes']
        if axis != rm2.attributes['axes']:
            return None
        # A list is accepted only when it names exactly one axis.
        if isinstance(axis, list):
            if len(axis) != 1:
                return None
            axis = axis[0]
        if not isinstance(axis, int):
            return None

    # Epsilon has to be a single-element constant on the last input of `eps`.
    eps_var = eps.inputs[-1]
    if not eps_var.is_parameter:
        return None
    eps_tensor = convert_any_to_torch_tensor(eps_var.value).cpu()
    if eps_tensor.numel() != 1:
        return None
    epsilon = eps_tensor.item()

    # All checks passed; the graph is only mutated from this point on.
    # Empty the current producer's output list so that the new op can take
    # over layernorm_output_var.
    layernorm_output_var.source_op.outputs.clear()

    scale_var = self.graph.create_variable(value=scale, is_parameter=True)
    attributes = {'axis': axis, 'epsilon': epsilon, 'stash_type': 0}
    layernorm = self.graph.create_operation(
        op_type='LayerNormalization',
        attributes=attributes,
        inputs=[layernorm_input_var, scale_var],
        outputs=[layernorm_output_var])

    if bias is None:
        return layernorm

    # Presumably attaches the bias as a further input of the new op; this
    # relies on how create_link_with_op treats A=None - verify against graph API.
    bias_var = self.graph.create_variable(value=bias, is_parameter=True)
    self.graph.create_link_with_op(variable=bias_var, A=None, B=layernorm)
    return layernorm
730
731 search_engine = SearchableGraph(graph=self.graph)
732 fused = False
733
734 # pattern 1:
735 # --- --- --- --- --- --- --- --
736 # | |
737 # ***(0) --- ReduceMean(1) --- Sub(2) --- Pow(3) --- ReduceMean(4) --- Add(5) --- Sqrt(6) --- Div(7) --- Mul(8) --- (Add)(9)
738 # | |
739 # --- --- --- ---
740 matches = search_engine.pattern_matching(
741 patterns=[lambda x: True, lambda x: x.type in {'ReduceMean', 'GlobalAveragePool'}, 'Sub', 'Pow',
742 lambda x: x.type in {'ReduceMean', 'GlobalAveragePool'}, 'Add', 'Sqrt', 'Div', 'Mul'],
743 edges=[[0, 1], [0, 2], [1, 2], [2, 3], [2, 7], [3, 4], [4, 5], [5, 6], [6, 7], [7, 8]], exclusive=exclusive_search)
744
745 for _, rm1, sub, pow, rm2, add, sqrt, div, mul in matches:
746 layernorm_ops = [rm1, sub, pow, rm2, add, sqrt, div, mul]
747
748 layernorm_scale = mul.inputs[-1].value
749 layernorm_output_var = div.outputs[0]
750 layernorm_input_var = sub.inputs[0]

Callers 1

Calls 7

pattern_matchingMethod · 0.95
SearchableGraphClass · 0.90
ppq_warningFunction · 0.90
remove_variableMethod · 0.80
remove_operationMethod · 0.80
appendMethod · 0.45

Tested by

no test coverage detected