| 30 | return self.replace_batchnorm_to_scale() |
| 31 | |
def replace_op(self, op_name: str, replace_to: Operation):
    """Swap the operation registered as ``op_name`` for ``replace_to``.

    ``replace_to`` takes over the original operation's wiring:
      * its ``inputs`` list becomes the original's inputs, and every input
        variable that referenced the original in ``dest_ops`` now references
        ``replace_to`` at the same position;
      * its ``outputs`` list becomes the original's outputs, and every output
        variable's ``source_op`` is set to ``replace_to``;
      * its ``parameters`` list is refilled from the original's parameters.

    Whatever ``replace_to`` previously held in those lists is dropped from
    its own lists only; the previously attached variables are not updated.
    The graph's operation table keeps the ``op_name`` key and maps it to
    ``replace_to``; ``replace_to``'s own name is not touched here.

    Args:
        op_name: Key of the operation to replace in ``self._graph.operations``.
        replace_to: Operation that takes the original's place.

    Raises:
        KeyError: If ``op_name`` is not in the current graph.
        ValueError: If an input variable's ``dest_ops`` does not contain the
            original operation (graph wiring out of sync).
    """
    if op_name not in self._graph.operations:
        raise KeyError(f'Operation {op_name} is not in current graph')
    operation = self._graph.operations[op_name]

    # Snapshot the original wiring BEFORE mutating ``replace_to``. If
    # ``replace_to`` is ``operation`` itself (or shares a list object with
    # it), clearing its lists first would also erase the data being copied.
    inputs = list(operation.inputs)
    outputs = list(operation.outputs)
    parameters = list(operation.parameters)

    replace_to.inputs.clear()
    replace_to.inputs.extend(inputs)
    # A variable fed into the op twice appears twice in ``inputs``, so each
    # of its ``dest_ops`` entries for ``operation`` is rewritten in turn.
    for input_var in inputs:
        dest_idx = input_var.dest_ops.index(operation)
        input_var.dest_ops[dest_idx] = replace_to

    replace_to.outputs.clear()
    replace_to.outputs.extend(outputs)
    for output_var in outputs:
        output_var.source_op = replace_to

    # NOTE(review): if ``parameters`` is a computed view (e.g. derived from
    # ``inputs``) rather than a stored list, these two calls have no lasting
    # effect — confirm against the Operation class.
    replace_to.parameters.clear()
    replace_to.parameters.extend(parameters)

    self._graph.operations[op_name] = replace_to
| 52 | |
| 53 | def replace_var(self, var_name: str, replace_to: Variable): |
| 54 | if var_name not in self._graph.variables: |