Slice: opset1 格式跟其他的不太一样,这个 pass 将 opset1 的 slice 强行转换为 opset 11
(self)
| 201 | return self.convert_to_tensor() |
| 202 | |
def format_slice(self) -> None:
    """Rewrite opset-1 style ``Slice`` operations into the input-based form.

    An opset-1 ``Slice`` stores ``starts``, ``ends`` and (optionally)
    ``axes`` as attributes, unlike later opsets.  For every Slice in
    ``self.graph`` that still has a ``starts`` attribute, this pass removes
    those attributes and re-attaches their values, in the order ``starts``,
    ``ends``, ``axes``, via ``self.__add_constant_input``.

    Slice operations without a ``starts`` attribute are left untouched.

    Raises:
        AssertionError: if a Slice has ``starts`` but no ``ends`` attribute.
    """
    # Collect first, mutate afterwards: the second loop changes operations,
    # so it must not run while iterating the graph's operation mapping.
    legacy_slices = []
    for operation in self.graph.operations.values():
        if operation.type != 'Slice' or 'starts' not in operation.attributes:
            continue
        assert 'ends' in operation.attributes, (
            f'Invalid Slice Operation Format, Slice operation is expected to have axes, '
            'starts and ends attributes with opset 1, '
            f'however your operation {operation.name}, do not have completed attributes')
        legacy_slices.append(operation)

    for slice_op in legacy_slices:
        assert isinstance(slice_op, Operation)
        starts = slice_op.attributes.pop('starts')
        ends = slice_op.attributes.pop('ends')
        axes = slice_op.attributes.pop('axes', None)
        if axes is None:
            # ONNX Slice-1: omitted axes default to [0, ..., len(starts) - 1].
            # `starts` is a sequence, so its length (not the sequence itself)
            # must be handed to range().
            axes = list(range(len(starts)))

        # Input order matters: starts, ends, axes.
        self.__add_constant_input(slice_op, convert_any_to_torch_tensor(starts))
        self.__add_constant_input(slice_op, convert_any_to_torch_tensor(ends))
        self.__add_constant_input(slice_op, convert_any_to_torch_tensor(axes))
| 230 | |
| 231 | def format_pad(self) -> None: |
| 232 | """ |
no test coverage detected