Graph Formatter offers a collection of graph-editing functions that help you modify your graph.
| 153 | |
| 154 | |
| 155 | class GraphFormatter(GraphCommandProcessor): |
| 156 | """ Graph Formatter offers a bunch of graph editing functions that helps modifying your graph. """ |
def _acceptable_command_types(self) -> List[GraphCommandType]:
    """Return the command types this processor accepts.

    Every type listed here is dispatched by ``process``. A new list is
    built on each call, so callers may mutate the result freely.
    """
    supported = (
        GraphCommandType.FORMAT_CLIP,
        GraphCommandType.FORMAT_PAD,
        GraphCommandType.FORMAT_GATHER,
        GraphCommandType.FORMAT_CAST,
        GraphCommandType.FORMAT_INT64_CONSTANT,
        GraphCommandType.DELETE_ISOLATED,
        GraphCommandType.FORMAT_PARAMETERS,
        GraphCommandType.FORMAT_CONSTANT_INPUT,
        GraphCommandType.FORMAT_SLICE,
        GraphCommandType.TRUNCATE_ON_VAR,
        GraphCommandType.FORMAT_RESIZE,
        GraphCommandType.REMOVE_IDENTITY,
        GraphCommandType.CONVERT_TO_TENSOR,
    )
    return list(supported)
| 173 | |
def process(self, command: GraphCommand) -> Any:
    """Run the formatting pass selected by ``command.command_type``.

    Args:
        command: The command to execute. A ``TRUNCATE_ON_VAR`` command must
            be a ``TruncateGraphCommand``, because its ``var`` and
            ``mark_as_output`` fields are forwarded to ``truncate_on_var``.

    Returns:
        Whatever the selected pass returns, or ``None`` when the command
        type matches none of the handled types.

    Raises:
        AssertionError: If a ``TRUNCATE_ON_VAR`` command is not a
            ``TruncateGraphCommand``.
    """
    requested = command.command_type

    # Truncation is the only pass that takes arguments from the command,
    # so it is handled separately from the argument-free passes below.
    if requested == GraphCommandType.TRUNCATE_ON_VAR:
        assert isinstance(command, TruncateGraphCommand), f'Use TruncateGraphCommand here.'
        return self.truncate_on_var(command.var, command.mark_as_output)

    # Argument-free passes, as (command type, method name) pairs. Methods
    # are looked up by name only once a type matches, so no handler is
    # resolved unless it is actually about to be called.
    argless_passes = (
        (GraphCommandType.FORMAT_CLIP, 'format_clip'),
        (GraphCommandType.FORMAT_PAD, 'format_pad'),
        (GraphCommandType.FORMAT_GATHER, 'format_gather'),
        (GraphCommandType.FORMAT_CAST, 'format_cast'),
        (GraphCommandType.DELETE_ISOLATED, 'delete_isolated'),
        (GraphCommandType.FORMAT_INT64_CONSTANT, 'format_int64_constant'),
        (GraphCommandType.FORMAT_PARAMETERS, 'format_parameter'),
        (GraphCommandType.FORMAT_CONSTANT_INPUT, 'remove_constant_input'),
        (GraphCommandType.FORMAT_SLICE, 'format_slice'),
        (GraphCommandType.FORMAT_RESIZE, 'format_resize'),
        (GraphCommandType.REMOVE_IDENTITY, 'remove_identity'),
        (GraphCommandType.CONVERT_TO_TENSOR, 'convert_to_tensor'),
    )
    for command_type, method_name in argless_passes:
        if requested == command_type:
            return getattr(self, method_name)()

    # No handled type matched.
    return None
| 202 | |
| 203 | def format_slice(self) -> None: |
| 204 | """ |
| 205 | Slice: opset1 格式跟其他的不太一样,这个 pass 将 opset1 的 slice 强行转换为 opset 11 |
| 206 | """ |
| 207 | interested_ops = [] |
| 208 | for operation in self.graph.operations.values(): |
| 209 | if operation.type == 'Slice': |
| 210 | if 'starts' in operation.attributes: |
| 211 | assert 'starts' in operation.attributes and 'ends' in operation.attributes, ( |
| 212 | f'Invalid Slice Operation Format, Slice operation is expected to have axes, ' |
no outgoing calls
no test coverage detected