(self, graph: BaseGraph,
dataloader: Iterable, executor: BaseGraphExecutor,
**kwargs)
| 156 | return r_value, s_value, processed_values |
| 157 | |
def optimize(self, graph: BaseGraph,
             dataloader: Iterable, executor: BaseGraphExecutor,
             **kwargs) -> None:
    """Split each layer named in ``self.interested_layers`` into two parallel branches.

    Every listed operation is validated and then handed to ``self.h_split``,
    which returns ``(r_value, s_value, processed_values)``. When
    ``processed_values`` is positive the graph is rewritten in place:

    * a second operation with the same type, a copy of the attributes and
      the same platform is created and fed by the same input variable;
    * the clone receives its own parameter variable, created from a clone
      of the original weight;
    * ``r_value`` is copied into the original weight and ``s_value`` into
      the clone's weight;
    * the original output variable is re-sourced to a new ``Add`` operation
      that consumes the outputs of both branches.

    Layers for which ``processed_values`` is not positive are left untouched.
    The whole pass runs under ``torch.no_grad()``.

    Args:
        graph: Graph whose operations are split; it is mutated in place.
        dataloader: Not used by this method.
        executor: Not used by this method.
        **kwargs: Not used by this method.

    Raises:
        KeyError: A name in ``self.interested_layers`` is not in
            ``graph.operations``.
        TypeError: The operation type is not one of Gemm, MatMul, Conv or
            ConvTranspose, or the operation is a ``QuantableOperation``.
        ValueError: Input 1 of the operation is not a parameter.
    """
    splittable_types = {'Gemm', 'MatMul', 'Conv', 'ConvTranspose'}

    with torch.no_grad():
        for layer_name in self.interested_layers:
            # Validate the target before touching the graph.
            if layer_name not in graph.operations:
                raise KeyError(f'Operation {layer_name} is not in current graph.')

            origin = graph.operations[layer_name]
            if origin.type not in splittable_types:
                raise TypeError(f'Operation {origin.name} can not be splited, op type is invalid({origin.type})')
            if not origin.inputs[1].is_parameter:
                raise ValueError(f'Operation {origin.name} can not be splited, input 1 is not parameter.')
            if isinstance(origin, QuantableOperation):
                raise TypeError(f'Can not split a quantized operation, '
                                'Layer Split Pass should only be invoked as a pre-quant optimziation.')

            r_value, s_value, processed_values = self.h_split(origin)

            # Nothing was processed for this layer: keep it as it is.
            if not processed_values > 0:
                continue

            # Build the twin branch: same type / attributes / platform.
            twin = graph.create_operation(
                op_type=origin.type,
                attributes=origin.attributes.copy(),
                platform=origin.platform,
            )
            input_var = origin.inputs[0]
            output_var = origin.outputs[0]

            # Feed the twin from the very same input variable as the origin.
            graph.create_link_with_op(
                variable=input_var, A=input_var.source_op, B=twin)

            # The twin gets its own weight variable, seeded from the original weight.
            twin_weight = origin.inputs[1].value.clone()
            graph.create_variable(
                value=twin_weight, is_parameter=True, dest_ops=[twin])

            # Overwrite both weights in place with the two split parts.
            origin.inputs[1].value.copy_(r_value)
            twin.inputs[1].value.copy_(s_value)

            # Detach the original output from the origin and hand it to a new
            # Add operation, so downstream consumers keep the same variable.
            origin.outputs.clear()
            merger = graph.create_operation(
                op_type='Add', platform=origin.platform, outputs=[output_var])
            output_var.source_op = merger

            # Both branches feed the Add; origin first, twin second.
            for branch in (origin, twin):
                graph.create_link_with_op(A=branch, B=merger)
| 200 | |
| 201 | |
| 202 | class MetaxGemmSplitPass(QuantizationOptimizationPass): |
nothing calls this directly
no test coverage detected