MCPcopy
hub / github.com/OpenPPL/ppq / optimize

Method optimize

ppq/quantization/optim/morph.py:233–371  ·  view source on GitHub ↗
(self, graph: BaseGraph,
                 dataloader: Iterable, executor: BaseGraphExecutor,
                 **kwargs)

Source from the content-addressed store, hash-verified

231
232 # Implementation of Gemm Split will move to IR.morph soon.
233 def optimize(self, graph: BaseGraph,
234 dataloader: Iterable, executor: BaseGraphExecutor,
235 **kwargs) -> None:
236
237 interested_ops = []
238 for operation in graph.operations.values():
239 if operation.type == 'GRU':
240 interested_ops.append(operation)
241
242 for op in interested_ops:
243 assert isinstance(op, Operation)
244 # fetch all related variables
245 rnn_x, rnn_w, rnn_r, rnn_b, _, rnn_h = op.inputs
246 hidden_size = op.attributes['hidden_size']
247
248 # Take a further look at
249 # https://github.com/onnx/onnx/blob/main/docs/Operators.md#GRU
250 Wz = rnn_w.value[0, hidden_size * 0: hidden_size * 1]
251 Wr = rnn_w.value[0, hidden_size * 1: hidden_size * 2]
252 Wh = rnn_w.value[0, hidden_size * 2: hidden_size * 3]
253
254 Rz = rnn_r.value[0, hidden_size * 0: hidden_size * 1]
255 Rr = rnn_r.value[0, hidden_size * 1: hidden_size * 2]
256 Rh = rnn_r.value[0, hidden_size * 2: hidden_size * 3]
257
258 Wbz = rnn_b.value[0, hidden_size * 0: hidden_size * 1]
259 Wbr = rnn_b.value[0, hidden_size * 1: hidden_size * 2]
260 Wbh = rnn_b.value[0, hidden_size * 2: hidden_size * 3]
261
262 Rbz = rnn_b.value[0, hidden_size * 3: hidden_size * 4]
263 Rbr = rnn_b.value[0, hidden_size * 4: hidden_size * 5]
264 Rbh = rnn_b.value[0, hidden_size * 5: hidden_size * 6]
265
266 # create operations
267 op1 = graph.create_operation(op_type='Gemm', attributes={'transB': 1})
268 op2 = graph.create_operation(op_type='Gemm', attributes={'transB': 1})
269 op3 = graph.create_operation(op_type='Add')
270 op4 = graph.create_operation(op_type='Sigmoid')
271 op5 = graph.create_operation(op_type='Slice')
272 op6 = graph.create_operation(op_type='Slice')
273 op7 = graph.create_operation(op_type='Gemm', attributes={'transB': 1})
274 op8 = graph.create_operation(op_type='Gemm', attributes={'transB': 1})
275 op9 = graph.create_operation(op_type='Mul')
276 op10 = graph.create_operation(op_type='Mul')
277 op11 = graph.create_operation(op_type='Sub')
278 op12 = graph.create_operation(op_type='Add')
279 op13 = graph.create_operation(op_type='Mul')
280 op14 = graph.create_operation(op_type='Tanh')
281 op15 = graph.create_operation(op_type='Add')
282
283 # create parameter variables
284 # To speed up computation, Wz and Wr are merged into Wzr; Rz/Rr and the biases are merged likewise (Rzr, Wbzr, Rbzr).
285 # See https://github.com/onnx/onnx/blob/main/docs/Operators.md#GRU
286 Wzr_var = graph.create_variable(value=torch.cat([Wz, Wr]), is_parameter=True)
287 Rzr_var = graph.create_variable(value=torch.cat([Rz, Rr]), is_parameter=True)
288 Wbzr_var = graph.create_variable(value=torch.cat([Wbz, Wbr]), is_parameter=True)
289 Rbzr_var = graph.create_variable(value=torch.cat([Rbz, Rbr]), is_parameter=True)
290

Callers 1

quantizeMethod · 0.45

Calls 9

delete_hidden_vecMethod · 0.95
create_operationMethod · 0.80
create_variableMethod · 0.80
toMethod · 0.80
create_link_with_opMethod · 0.80
remove_operationMethod · 0.80
appendMethod · 0.45
clearMethod · 0.45

Tested by

no test coverage detected