(attr, input_shape, kernel_shape=None, op_type=None)
| 5 | |
| 6 | # attribute checker and preprocess |
def process_attribute(attr, input_shape, kernel_shape=None, op_type=None):
    """Resolve ``auto_pad`` / ``output_shape`` into an explicit ``pads`` entry.

    ``attr`` is modified in place and nothing is returned.

    Args:
        attr: Mutable mapping of node attributes. Reads ``auto_pad``,
            ``strides``, ``dilations``, ``kernel_shape``, ``output_shape``,
            ``output_padding`` and ``pads``; may write ``pads`` and remove
            ``auto_pad``.
        input_shape: Sizes of the axes being padded, one entry per axis.
            NOTE(review): the ``[1, 1]`` / ``[0, 0]`` defaults assume two
            axes -- confirm callers pass spatial dims only.
        kernel_shape: Fallback kernel size used when ``attr`` has no
            ``kernel_shape`` entry.
        op_type: Operator name; ``'ConvTranspose'`` selects the
            transposed-convolution padding arithmetic.

    Raises:
        ValueError: If ``auto_pad`` is set to something other than
            ``NOTSET``, ``VALID``, ``SAME_UPPER`` or ``SAME_LOWER`` and
            ``attr`` carries no explicit ``pads``.
    """
    auto_pad = attr.get('auto_pad', 'NOTSET')
    strides = attr.get('strides', [1, 1])
    dilations = attr.get('dilations', [1, 1])
    kernels = attr.get('kernel_shape', kernel_shape)
    axes = range(len(input_shape))
    pad_needed = None

    def transpose_total_padding(output_shape):
        # `output_padding` is only used to find output shape, but does not
        # actually add zero-padding to output.
        out_pad = attr.get('output_padding', [0, 0])
        return [
            (input_shape[i] - 1) * strides[i]
            + dilations[i] * (kernels[i] - 1) + 1
            + out_pad[i]
            - output_shape[i]
            for i in axes
        ]

    if op_type == 'ConvTranspose' and 'output_shape' in attr:
        # An explicit output shape fixes the total padding per axis.
        # (Previously this applied `%` to the whole shape list, which
        # raises TypeError; the per-axis output_padding is what belongs
        # in the formula.)
        pad_needed = transpose_total_padding(attr['output_shape'])

    if auto_pad != 'NOTSET':
        if 'pads' in attr:
            logger.warning('auto_pad is conflict with pads attribute. Use pads here.')
        elif auto_pad == 'VALID':
            attr['pads'] = [0, 0, 0, 0]
        elif auto_pad in ('SAME_UPPER', 'SAME_LOWER'):
            if op_type == 'ConvTranspose':
                # SAME_* for a transposed conv means output = input * stride.
                pad_needed = transpose_total_padding(
                    [input_shape[i] * strides[i] for i in axes])
            else:
                # SAME_* for a regular conv means output = ceil(input / stride).
                output_shape = [(input_shape[i] + strides[i] - 1) // strides[i] for i in axes]
                pad_needed = [
                    (output_shape[i] - 1) * strides[i]
                    + dilations[i] * (kernels[i] - 1) + 1
                    - input_shape[i]
                    for i in axes
                ]
        else:
            raise ValueError(f'Invalid auto_pad value {auto_pad}')

    if pad_needed is not None:
        # SAME_UPPER puts the odd leftover element at the end; every other
        # mode puts it at the beginning.
        begins = [(total if auto_pad == 'SAME_UPPER' else total + 1) // 2
                  for total in pad_needed]
        ends = [total - begin for total, begin in zip(pad_needed, begins)]
        # onnx pads format should be as follow [x1_begin, x2_begin...x1_end, x2_end,...]
        attr['pads'] = begins + ends
        # onnx pads attribute cannot be used simultaneously with auto_pad
        # attribute. auto_pad may be absent when the padding came from
        # output_shape, so removal must tolerate a missing key.
        attr.pop('auto_pad', None)
| 51 | |
| 52 | |
| 53 | def preprocess_attr(attr, op_type=None): |
no test coverage detected