Repository: github.com/OpenPPL/ppq

Method fuse_bias_add

ppq/IR/morph.py:885–919

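Fuses Conv + Add, ConvTranspose + Add, and Gemm + Add patterns by folding the Add's constant operand into the upstream operation's bias input. The fusion only applies when the Add has exactly one constant (parameter) input.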
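The fusion is safe for Conv because a convolution's bias adds one constant per output channel, which is exactly what an Add with a constant of shape [1, C, 1, 1] does. The sketch below checks this numerically with a 1×1 convolution on plain nested lists; `conv1x1`, `add_const` and `squeeze_channel_const` are hypothetical helpers for illustration, not PPQ functions, and the same argument holds for larger kernels.

```python
def conv1x1(x, w, bias=None):
    """1x1 convolution. x: [N][C_in][H][W], w: [C_out][C_in], bias: [C_out] or None."""
    out = []
    for img in x:
        planes = []
        for o, w_row in enumerate(w):
            b = bias[o] if bias is not None else 0  # one bias value per output channel
            planes.append([[sum(w_row[i] * img[i][r][c] for i in range(len(w_row))) + b
                            for c in range(len(img[0][0]))]
                           for r in range(len(img[0]))])
        out.append(planes)
    return out


def add_const(y, const):
    """Add a constant of shape [1][C][1][1], broadcast over N, H and W."""
    return [[[[v + const[0][c][0][0] for v in row] for row in plane]
             for c, plane in enumerate(img)]
            for img in y]


def squeeze_channel_const(const):
    """[1][C][1][1] -> [C]: the 1-D shape that Conv's bias input requires."""
    return [plane[0][0] for plane in const[0]]


x = [[[[1, 2], [3, 4]],               # N=1, C_in=2, H=W=2
      [[0, 1], [1, 0]]]]
w = [[1, -1], [2, 0], [0, 3]]         # C_out=3
const = [[[[10]], [[20]], [[30]]]]    # shape [1, 3, 1, 1]

unfused = add_const(conv1x1(x, w), const)                 # Conv, then Add
fused = conv1x1(x, w, bias=squeeze_channel_const(const))  # single Conv with bias
print(unfused == fused)   # True
print(fused[0][2])        # [[30, 33], [33, 30]]
```

Squeezing [1, C, 1, 1] down to [C] mirrors the `bias.value.squeeze()` call in the method, since ONNX Conv only accepts a 1-D bias.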

(self)


```python
def fuse_bias_add(self):
    """
    Fuse patterns like Conv + Add, ConvTranspose + Add and Gemm + Add.
    This fusion requires the Add to have a constant input, which becomes the bias.
    """
    graph = self.graph
    # copy the op list: the loop removes operations from the graph
    for op in list(graph.operations.values()):
        if op.type in {'Conv', 'ConvTranspose', 'Gemm'}:
            # channels sit on dim 1 for NCH, NCHW and NCDHW outputs;
            # only the Conv/ConvTranspose branch below reads this value
            channel_dimension = 1

            # check if current op has only 1 downstream op
            if len(graph.get_downstream_operations(op)) == 1:
                down = graph.get_downstream_operations(op)[0]

                if down.type == 'Add':
                    if down.num_of_parameter != 1: continue

                    # op already has a bias: keep the Add instead of dropping its
                    # constant (checked before the squeeze below mutates it)
                    if op.num_of_input == 3: continue

                    bias = down.parameters[0]
                    if op.type not in {'Gemm'}:
                        # check if it is a per-channel bias add
                        if bias.value.dim() != op.parameters[0].value.dim(): continue
                        if bias.value.squeeze().dim() != 1: continue
                        if bias.value.shape[channel_dimension] == 1: continue
                        bias.value = bias.value.squeeze()  # conv bias can only be 1-D
                    else:
                        # Gemm's C input may be any shape unidirectionally broadcastable to (M, N).
                        # see https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gemm-11
                        pass

                    # ready for fusion: attach the constant as op's bias, then drop the Add
                    graph.create_variable(is_parameter=True, value=bias.value, dest_ops=[op])
                    graph.remove_operation(removing_op=down, keep_coherence=True)
```
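The three shape checks in the Conv/ConvTranspose branch can be restated on plain shape tuples to see which constants qualify. This is a sketch that models tensors by their shapes only and assumes torch's `squeeze()` semantics (every size-1 dimension is dropped); `is_channel_bias`, `squeezed_rank` and the `weight_rank` parameter are hypothetical names.

```python
def squeezed_rank(shape):
    """Rank after a torch-style squeeze(): every size-1 dimension is dropped."""
    return sum(1 for d in shape if d != 1)


def is_channel_bias(bias_shape, weight_rank, channel_dim=1):
    """Mirror of the Conv/ConvTranspose checks in fuse_bias_add.

    The constant must have the same rank as the conv weight (and hence as the
    conv output), exactly one non-unit dimension, and that dimension must be
    the channel axis.
    """
    if len(bias_shape) != weight_rank:
        return False
    if squeezed_rank(bias_shape) != 1:
        return False
    return bias_shape[channel_dim] != 1


# Expected: only (1, 64, 1, 1) is accepted for a 2-D conv (weight rank 4).
for shape in [(1, 64, 1, 1), (64,), (1, 1, 1, 64), (1, 64, 7, 7), (1, 1, 1, 1)]:
    print(shape, is_channel_bias(shape, weight_rank=4))
```

Rejecting `(64,)` matters: ONNX Add broadcasts by aligning trailing dimensions, so a 1-D constant of length C added to an NCHW tensor is applied along W, not along channels, and folding it into the bias would change the result.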

Callers (2)

process (method) · 0.95
testFuseBias.py (file) · 0.80

Calls (3)

create_variable (method) · 0.80
remove_operation (method) · 0.80
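To make the two graph edits concrete, here is a toy re-implementation of the pass on a tiny hand-rolled IR. It is a sketch, not PPQ code: `Var`, `Op`, `downstream` and `fuse_bias_add_toy` are hypothetical, the shape checks are omitted, and it assumes that `remove_operation(..., keep_coherence=True)` reconnects the removed Add's consumers to the upstream operation's output.

```python
from dataclasses import dataclass, field


@dataclass(eq=False)  # compare by identity, like graph nodes
class Var:
    name: str
    value: object = None        # payload for constants; None for activations
    is_parameter: bool = False


@dataclass(eq=False)
class Op:
    name: str
    type: str
    inputs: list = field(default_factory=list)
    outputs: list = field(default_factory=list)


def downstream(ops, op):
    """Ops consuming any output of `op` (toy stand-in for get_downstream_operations)."""
    return [o for o in ops if any(v in o.inputs for v in op.outputs)]


def fuse_bias_add_toy(ops):
    for op in list(ops):
        if op.type not in {"Conv", "ConvTranspose", "Gemm"}:
            continue
        down = downstream(ops, op)
        if len(down) != 1 or down[0].type != "Add":
            continue
        add = down[0]
        params = [v for v in add.inputs if v.is_parameter]
        if len(params) != 1 or len(op.inputs) == 3:
            continue  # need exactly one constant and no existing bias
        # like create_variable(is_parameter=True, dest_ops=[op]): new bias input on op
        op.inputs.append(Var(f"{op.name}.bias", params[0].value, is_parameter=True))
        # like remove_operation(add, keep_coherence=True): drop the Add and let op
        # produce the Add's output variable, so its consumers stay connected
        ops.remove(add)
        op.outputs[0] = add.outputs[0]
    return ops


x, w, b = Var("x"), Var("w", "W", True), Var("b", "B", True)
y, z, out = Var("y"), Var("z"), Var("out")
conv = Op("conv", "Conv", [x, w], [y])
add = Op("add", "Add", [y, b], [z])
relu = Op("relu", "Relu", [z], [out])

graph = fuse_bias_add_toy([conv, add, relu])
print([o.type for o in graph])  # ['Conv', 'Relu']
print(conv.inputs[2].value)     # B
```

Handing the Add's output variable to the upstream op is one simple way to keep downstream consumers untouched; PPQ's own bookkeeping may differ in detail.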

Tested by

no test coverage detected