MCPcopy
hub / github.com/OpenPPL/ppq / _fuse

Method _fuse

ppq/IR/morph.py:696–729  ·  view source on GitHub ↗
(rm1: Operation, rm2: Operation, 
                  eps: Operation, scale: torch.Tensor, 
                  bias: torch.Tensor, layernorm_input_var: Variable, 
                  layernorm_output_var: Variable)

Source from the content-addressed store, hash-verified

694 """Fuse Layernormalization with pattern matching."""
695
def _fuse(rm1: Operation, rm2: Operation,
          eps: Operation, scale: torch.Tensor,
          bias: torch.Tensor, layernorm_input_var: Variable,
          layernorm_output_var: Variable) -> Operation:
    """Collapse one matched decomposed layer-norm subgraph into a LayerNormalization op.

    This is a closure: it reads ``self.graph`` from the enclosing method.
    All validation happens before the graph is modified, so returning None
    leaves the graph untouched.

    Args:
        rm1, rm2: The two mean-reduction operations of the matched subgraph.
            When both are ``ReduceMean`` they must carry identical ``axes``
            attributes naming exactly one axis, and that axis must be the
            last dimension of the input. Otherwise the fused op would
            normalize over a different set of dimensions. For any other op
            types the axis defaults to -1.
        eps: Operation whose last input must be a single-element parameter;
            that value becomes the ``epsilon`` attribute.
        scale: Tensor stored as a new parameter variable and wired in as the
            second input of the new op.
        bias: Optional tensor. When not None it is stored as a parameter
            variable and linked to the new op as an additional input.
        layernorm_input_var: Variable used as the new op's data input.
        layernorm_output_var: Variable that becomes the new op's output. The
            output list of its current producer is cleared first.

    Returns:
        The new LayerNormalization operation, or None when the subgraph cannot
        be fused safely.
    """
    if rm2.type == rm1.type == 'ReduceMean':
        if 'axes' not in rm1.attributes: return None
        if 'axes' not in rm2.attributes: return None
        if rm1.attributes['axes'] != rm2.attributes['axes']: return None
        layernorm_axis = rm1.attributes['axes']
        # A single reduction axis may be encoded as a list or a tuple.
        if isinstance(layernorm_axis, (list, tuple)):
            if len(layernorm_axis) != 1: return None
            layernorm_axis = layernorm_axis[0]
        if not isinstance(layernorm_axis, int): return None

        # ONNX LayerNormalization normalizes over every dimension in
        # [axis, rank), whereas the ReduceMean pair normalizes over `axis` only.
        # They agree only when `axis` is the last dimension. A channels-first
        # norm that reduces axis 1 of an NCHW tensor must NOT be fused.
        if layernorm_axis != -1:
            # -2, -3, ... name earlier dimensions, never just the last one.
            if layernorm_axis < 0: return None
            # A non-negative axis can only be proven "last" when the input rank
            # is known. Assumes ppq Variable exposes `.shape` (None when not
            # inferred) -- TODO confirm. Without a shape, skip the fusion
            # rather than risk producing a wrong graph.
            input_shape = getattr(layernorm_input_var, 'shape', None)
            if input_shape is None: return None
            if layernorm_axis != len(input_shape) - 1: return None
    else: layernorm_axis = -1

    # epsilon must be a single constant scalar.
    if not eps.inputs[-1].is_parameter: return None
    value = eps.inputs[-1].value
    value = convert_any_to_torch_tensor(value).cpu()
    if value.numel() != 1: return None
    layernorm_eps = value.item()

    # Detach the output variable from its old producer so the new op can own it.
    layernorm_output_var.source_op.outputs.clear()
    layernorm = self.graph.create_operation(
        op_type = 'LayerNormalization',
        attributes = {'axis': layernorm_axis, 'epsilon': layernorm_eps, 'stash_type': 0},
        inputs = [layernorm_input_var, self.graph.create_variable(value=scale, is_parameter=True)],
        outputs = [layernorm_output_var])

    if bias is not None:
        # The bias is optional; when present it is linked as an additional input.
        self.graph.create_link_with_op(
            variable=self.graph.create_variable(value=bias, is_parameter=True),
            A=None, B=layernorm)
    return layernorm
730
731 search_engine = SearchableGraph(graph=self.graph)
732 fused = False

Callers

nothing calls this directly

Calls 5

create_operation · Method · 0.80
create_variable · Method · 0.80
create_link_with_op · Method · 0.80
clear · Method · 0.45

Tested by

no test coverage detected