class GraphDecomposer(GraphCommandProcessor):
    """GraphDecomposer, introduced in PPQ 0.6.4, splits some complex
    operations into simpler ones. For example, Gemm can be split into a
    MatMul followed by a bias Add.

    Gemm
        General Matrix multiplication: https://en.wikipedia.org/wiki/Basic_Linear_Algebra_Subprograms#Level_3

        A' = transpose(A) if transA else A

        B' = transpose(B) if transB else B

        Compute Y = alpha * A' * B' + beta * C, where input tensor A has shape (M, K) or (K, M),
        input tensor B has shape (K, N) or (N, K), input tensor C is broadcastable to shape (M, N),
        and output tensor Y has shape (M, N). A will be transposed before doing the computation if attribute transA is non-zero,
        same for B and transB.

        This operator supports unidirectional broadcasting (tensor C should be unidirectional broadcastable to tensor A * B);
        for more details please check the doc. This operator has optional inputs/outputs.

        See the doc for more details about the representation of optional arguments.
        An empty string may be used in the place of an actual argument's name to indicate a missing argument.
        Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted.

    Attributes
        alpha : float (default is 1.0)
            Scalar multiplier for the product of input tensors A * B.

        beta : float (default is 1.0)
            Scalar multiplier for input tensor C.

        transA : int (default is 0)
            Whether A should be transposed.

        transB : int (default is 0)
            Whether B should be transposed.
    """

    def process(self, command: GraphCommand) -> Any:
        return super().process(command)

    def _acceptable_command_types(self) -> List[GraphCommandType]:
        return super()._acceptable_command_types

    def decompose_gemm(self):
        graph = self.graph

        # Collect the Gemm operations first, so the graph is not modified
        # while graph.operations is being iterated.
        interested_ops = []
        for operation in graph.operations.values():
            if operation.type == 'Gemm':
                interested_ops.append(operation)

        for op in interested_ops:
            assert isinstance(op, Operation)
            output_var = op.outputs[0]

            # A third input means the Gemm carries a bias term C,
            # which becomes a separate Add operation.
            if op.num_of_input == 3:
                bias_add = graph.create_operation(op_type='Add', platform=op.platform)
                bias_var = op.inputs[-1]
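The equivalence that the docstring describes (Gemm as a MatMul followed by a bias Add) can be checked numerically with a small standalone sketch. It does not use PPQ and is not the library's implementation: `transpose`, `matmul`, `gemm_reference` and `gemm_decomposed` are hypothetical helper names, C is assumed to be a 1-D bias of length N that is broadcast over the rows, and the decomposed form assumes `alpha == beta == 1.0`.

```python
from typing import List, Optional

Matrix = List[List[float]]


def transpose(m: Matrix) -> Matrix:
    """Return the transpose of a rectangular matrix."""
    return [list(row) for row in zip(*m)]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain (M, K) x (K, N) matrix product."""
    b_cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in b_cols] for row in a]


def gemm_reference(a: Matrix, b: Matrix, c: Optional[List[float]] = None,
                   alpha: float = 1.0, beta: float = 1.0,
                   trans_a: int = 0, trans_b: int = 0) -> Matrix:
    """Fused form: Y = alpha * A' * B' + beta * C (C is a length-N bias)."""
    a = transpose(a) if trans_a else a
    b = transpose(b) if trans_b else b
    y = [[alpha * v for v in row] for row in matmul(a, b)]
    if c is not None:
        y = [[v + beta * bias for v, bias in zip(row, c)] for row in y]
    return y


def gemm_decomposed(a: Matrix, b: Matrix, c: Optional[List[float]] = None,
                    trans_a: int = 0, trans_b: int = 0) -> Matrix:
    """Two-step form: MatMul, then an optional bias Add.

    Assumption for this sketch: alpha == beta == 1.0.
    """
    a = transpose(a) if trans_a else a
    b = transpose(b) if trans_b else b
    y = matmul(a, b)  # the MatMul step
    if c is not None:  # the Add step exists only when a third input is given
        y = [[v + bias for v, bias in zip(row, c)] for row in y]
    return y


a = [[1, 2], [3, 4]]
b = [[5, 6], [7, 8]]
c = [10, 20]

print(gemm_decomposed(a, b, c))  # [[29, 42], [53, 70]]
print(gemm_reference(a, b, c) == gemm_decomposed(a, b, c))  # True
```

The bias step is skipped when `c` is `None`, which mirrors the `op.num_of_input == 3` check in `decompose_gemm`: an Add is only needed when the Gemm actually has a third input.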