MCPcopy
hub / github.com/OpenPPL/ppq / fuse_gemm

Method fuse_gemm

ppq/IR/morph.py:617–691  ·  view source on GitHub ↗

Fuse MatMul + Add into a single Gemm. A single MatMul will be replaced with Gemm. Returns: _type_: _description_

(self)

Source from the content-addressed store, hash-verified

615 self.graph.append_variable(bias_var)
616
617 def fuse_gemm(self):
618 """Fuse MatMul + add into a singal Gemm
619 Single Matmul will be replaced with Gemm
620
621 Returns:
622 _type_: _description_
623 """
624
625 def _is_replaceable(op: Operation) -> bool:
626 if op.inputs[0].is_parameter == False and op.inputs[1].is_parameter == False:
627 return False
628 else:
629 return True
630
631 search_engine = SearchableGraph(graph=self.graph)
632 patterns = search_engine.pattern_matching(patterns=["MatMul", "Add"], edges=[[0, 1]], exclusive=True)
633 for pattern in patterns:
634 matmul, add = pattern
635
636 if _is_replaceable(matmul) == False:
637 continue
638
639 matmul.type = "Gemm"
640
641 matmul_out = matmul.outputs[0]
642 add_out = add.outputs[0]
643
644 if matmul.inputs[0].is_parameter:
645 temp = matmul.inputs[0]
646 matmul.inputs[0] = matmul.inputs[1]
647 matmul.inputs[1] = temp
648
649 assert len(add.inputs) == 2, "Oops, seems we got some problem here."
650 var1, var2 = add.inputs
651 bias_var = None
652
653 if var1.source_op == matmul and var2.is_parameter:
654 bias_var = var2
655
656 if var2.source_op == matmul and var1.is_parameter:
657 bias_var = var1
658
659 # can not find a valid bias, just skip add.
660 if bias_var is None:
661 continue
662
663 if len(bias_var.value.shape) == 1:
664 if bias_var.value.shape[0] == matmul.parameters[0].value.shape[-1]:
665 matmul.attributes["transB"] = 1
666 weight_val = matmul.parameters[0].value
667
668 matmul.parameters[0].value = weight_val.transpose(-1, -2)
669
670 bias_var.dest_ops.clear()
671 add.inputs.remove(bias_var)
672
673 # remove bias add, move bias to matmul
674 self.graph.remove_operation(add)

Callers 1

Calls 6

pattern_matching · Method · 0.95
SearchableGraph · Class · 0.90
remove_operation · Method · 0.80
create_link_with_op · Method · 0.80
create_link_with_var · Method · 0.80
clear · Method · 0.45

Tested by

no test coverage detected