(self, graph: BaseGraph,
dataloader: Iterable, executor: BaseGraphExecutor,
**kwargs)
| 231 | |
| 232 | # Implementation of Gemm Split will move to IR.morph soon. |
| 233 | def optimize(self, graph: BaseGraph, |
| 234 | dataloader: Iterable, executor: BaseGraphExecutor, |
| 235 | **kwargs) -> None: |
| 236 | |
| 237 | interested_ops = [] |
| 238 | for operation in graph.operations.values(): |
| 239 | if operation.type == 'GRU': |
| 240 | interested_ops.append(operation) |
| 241 | |
| 242 | for op in interested_ops: |
| 243 | assert isinstance(op, Operation) |
| 244 | # fetch all related variables |
| 245 | rnn_x, rnn_w, rnn_r, rnn_b, _, rnn_h = op.inputs |
| 246 | hidden_size = op.attributes['hidden_size'] |
| 247 | |
| 248 | # Take a further look at |
| 249 | # https://github.com/onnx/onnx/blob/main/docs/Operators.md#GRU |
| 250 | Wz = rnn_w.value[0, hidden_size * 0: hidden_size * 1] |
| 251 | Wr = rnn_w.value[0, hidden_size * 1: hidden_size * 2] |
| 252 | Wh = rnn_w.value[0, hidden_size * 2: hidden_size * 3] |
| 253 | |
| 254 | Rz = rnn_r.value[0, hidden_size * 0: hidden_size * 1] |
| 255 | Rr = rnn_r.value[0, hidden_size * 1: hidden_size * 2] |
| 256 | Rh = rnn_r.value[0, hidden_size * 2: hidden_size * 3] |
| 257 | |
| 258 | Wbz = rnn_b.value[0, hidden_size * 0: hidden_size * 1] |
| 259 | Wbr = rnn_b.value[0, hidden_size * 1: hidden_size * 2] |
| 260 | Wbh = rnn_b.value[0, hidden_size * 2: hidden_size * 3] |
| 261 | |
| 262 | Rbz = rnn_b.value[0, hidden_size * 3: hidden_size * 4] |
| 263 | Rbr = rnn_b.value[0, hidden_size * 4: hidden_size * 5] |
| 264 | Rbh = rnn_b.value[0, hidden_size * 5: hidden_size * 6] |
| 265 | |
| 266 | # create operations |
| 267 | op1 = graph.create_operation(op_type='Gemm', attributes={'transB': 1}) |
| 268 | op2 = graph.create_operation(op_type='Gemm', attributes={'transB': 1}) |
| 269 | op3 = graph.create_operation(op_type='Add') |
| 270 | op4 = graph.create_operation(op_type='Sigmoid') |
| 271 | op5 = graph.create_operation(op_type='Slice') |
| 272 | op6 = graph.create_operation(op_type='Slice') |
| 273 | op7 = graph.create_operation(op_type='Gemm', attributes={'transB': 1}) |
| 274 | op8 = graph.create_operation(op_type='Gemm', attributes={'transB': 1}) |
| 275 | op9 = graph.create_operation(op_type='Mul') |
| 276 | op10 = graph.create_operation(op_type='Mul') |
| 277 | op11 = graph.create_operation(op_type='Sub') |
| 278 | op12 = graph.create_operation(op_type='Add') |
| 279 | op13 = graph.create_operation(op_type='Mul') |
| 280 | op14 = graph.create_operation(op_type='Tanh') |
| 281 | op15 = graph.create_operation(op_type='Add') |
| 282 | |
| 283 | # create parameter variables |
| 284 | # To speed up computation, Wz and Wr are merged into Wzr; Rzr, Wbzr and Rbzr are built the same way. |
| 285 | # See https://github.com/onnx/onnx/blob/main/docs/Operators.md#GRU |
| 286 | Wzr_var = graph.create_variable(value=torch.cat([Wz, Wr]), is_parameter=True) |
| 287 | Rzr_var = graph.create_variable(value=torch.cat([Rz, Rr]), is_parameter=True) |
| 288 | Wbzr_var = graph.create_variable(value=torch.cat([Wbz, Wbr]), is_parameter=True) |
| 289 | Rbzr_var = graph.create_variable(value=torch.cat([Rbz, Rbr]), is_parameter=True) |
| 290 |
no test coverage detected