Fuse patterns like Conv + Add, ConvTranspose + Add, and Gemm + Add. This fusion requires a constant input as the bias.
(self)
| 883 | ppq_warning('No valid Gelu pattern was found, check your graph again.') |
| 884 | |
def fuse_bias_add(self):
    """
    Fuse patterns like Conv + Add, ConvTranspose + Add and Gemm + Add.

    The Add must have exactly one constant (parameter) input. That constant
    is folded into the bias input of the preceding operation, and the Add is
    then removed from the graph.

    - Conv / ConvTranspose: the constant must have the same rank as the
      weight and be non-trivial only along the channel dimension (dim 1),
      e.g. [1, C, 1, 1]. It is squeezed to the 1-D bias shape [C].
    - Gemm: ONNX Gemm computes Y = alpha * A' * B' + beta * C, where C must
      be unidirectionally broadcastable to (M, N). The constant must have
      rank <= 2 and is scaled by 1 / beta before being folded into C.
      Patterns with beta == 0 are skipped.

    If the operation already has a bias input, the constant is added to it.
    That existing bias must itself be a parameter; otherwise the pattern is
    skipped. The Add's own constant variable is never modified in place.
    """
    graph = self.graph
    # Snapshot the operations: fusion removes Add ops while we iterate.
    for op in list(graph.operations.values()):
        if op.type not in {'Conv', 'ConvTranspose', 'Gemm'}:
            continue

        # Only fuse when the Add is the sole consumer of op's output.
        # NOTE(review): this assumes op's output is not also a graph output;
        # confirm get_downstream_operations / remove_operation cover that case.
        downstream = graph.get_downstream_operations(op)
        if len(downstream) != 1:
            continue
        down = downstream[0]
        if down.type != 'Add' or down.num_of_parameter != 1:
            continue

        bias = down.parameters[0]
        if op.type == 'Gemm':
            # A rank > 2 constant would broadcast the Gemm output to a higher
            # rank, which a Gemm C input cannot express.
            if bias.value.dim() > 2:
                continue
            beta = op.attributes.get('beta', 1.0)
            # With beta == 0 the C input is ignored, so it cannot carry the bias.
            if beta == 0:
                continue
            # alpha*AB + beta*C + b == alpha*AB + beta*(C + b / beta)
            bias_value = bias.value if beta == 1 else bias.value / beta
        else:
            channel_dimension = 1  # NCHW, NCDHW, NCL
            weight = op.inputs[1]
            if not weight.is_parameter:
                continue
            # Check that this really is a per-channel bias add.
            if bias.value.dim() != weight.value.dim():
                continue
            if bias.value.squeeze().dim() != 1:
                continue
            if bias.value.shape[channel_dimension] == 1:
                continue
            # Conv bias must be 1-D. Squeeze out of place so the Add's
            # constant (possibly referenced elsewhere) is left untouched.
            bias_value = bias.value.squeeze()

        # Ready for fusion.
        if op.num_of_input == 3:
            # Already has a bias: accumulate into it instead of dropping the Add's constant.
            existing = op.inputs[2]
            if not existing.is_parameter:
                continue
            existing.value = existing.value + bias_value
        else:
            graph.create_variable(is_parameter=True, value=bias_value, dest_ops=[op])
        graph.remove_operation(removing_op=down, keep_coherence=True)
| 921 | def fuse_scale(self): |
| 922 | "Fuse Conv + Mul or Conv + Add" |
no test coverage detected