部分部署平台不支持 Constant Op 作为算子的输入，在这种情况下我们使用这个 pass 把它们切换成 Parameter Variable。 Some backend platforms do not support a Constant Op as an operator input; this pass replaces each such op by turning its value into a parameter variable.
(self)
| 343 | if all(check): value = value.int() |
| 344 | |
def remove_constant_input(self) -> None:
    """Replace every Constant operation with a parameter variable.

    Some backend platforms do not support a Constant Op as the input of
    another operator. This pass removes each Constant Op from the graph.
    It copies the op's ``value`` attribute into the op's single output
    variable and marks that variable as a parameter.

    Raises:
        AssertionError: if a Constant operation does not have exactly one
            output, which usually indicates a network parsing error.
    """
    # Collect and validate every Constant op before mutating anything.
    # A malformed op then aborts the pass with the graph untouched.
    # We also never remove entries from graph.operations while iterating it.
    constant_ops = [
        op for op in self.graph.operations.values() if op.type == 'Constant']
    for op in constant_ops:
        # The old message said "more than 1 output", which was wrong for
        # ops with zero outputs. Report the actual count instead.
        assert len(op.outputs) == 1, (
            f'Constant Operation {op.name} should have exactly 1 output, '
            f'but has {len(op.outputs)}, is there a network parsing error?')

    for const_op in constant_ops:
        assert isinstance(const_op, Operation)
        # Read the value first. A missing 'value' attribute then fails
        # before the output variable is modified.
        constant_value = const_op.attributes['value']
        output_var = const_op.outputs[0]
        # Promote the constant's output to a parameter that carries the value.
        output_var._is_parameter = True
        output_var.value = constant_value
        # Only the op is removed explicitly here. The output variable is not
        # deleted by this method. NOTE(review): confirm that remove_operation
        # keeps the variable for its downstream consumers.
        self.graph.remove_operation(removing_op=const_op)
| 367 | def truncate_on_var(self, var: Variable, mark_as_output: bool): |
| 368 | """从一个指定位置将图截断. |
no test coverage detected