(
input: TensorInfo, filter: TensorInfo, convInfo: backend_util.Conv2DInfo,
backend: NodeJSKernelBackend)
| 43 | }; |
| 44 | |
/**
 * Executes TensorFlow's `DepthwiseConv2dNative` op through the native binding.
 *
 * `strides` and `dilations` are emitted in the same dimension order as the
 * `data_format` attribute (NHWC or NCHW), because the TensorFlow op interprets
 * both attributes in data_format order. For `EXPLICIT` padding the per-side
 * amounts are forwarded through `explicit_paddings`, also in data_format order.
 *
 * @param input Input tensor (presumably rank 4 — the result is cast to
 *     `Tensor4D`; confirm against callers).
 * @param filter Depthwise filter tensor.
 * @param convInfo Convolution geometry; this function reads its strides,
 *     dilations, `padInfo` and `dataFormat`.
 * @param backend Node.js kernel backend used to build attrs and run the op.
 * @returns The op's single output, cast to `Tensor4D`.
 * @throws Error if `convInfo.padInfo.type` is not `VALID`, `SAME` or
 *     `EXPLICIT`.
 */
export function depthwiseConv2dNativeImpl(
    input: TensorInfo, filter: TensorInfo, convInfo: backend_util.Conv2DInfo,
    backend: NodeJSKernelBackend): Tensor4D {
  const padding = convInfo.padInfo.type;
  if (padding !== 'VALID' && padding !== 'SAME' && padding !== 'EXPLICIT') {
    throw new Error(
        `TF Backend supports only 'valid', 'same' and 'explicit' padding ` +
        `while padding was ${padding}`);
  }

  const isChannelsLast = convInfo.dataFormat === 'channelsLast';
  const dataFormat = isChannelsLast ? 'NHWC' : 'NCHW';
  // TF reads `strides` / `dilations` in data_format order; the batch and
  // channel entries must be 1, so place them according to the layout.
  const inFormatOrder = (height: number, width: number): number[] =>
      isChannelsLast ? [1, height, width, 1] : [1, 1, height, width];
  const strides = inFormatOrder(convInfo.strideHeight, convInfo.strideWidth);
  const dilations =
      inFormatOrder(convInfo.dilationHeight, convInfo.dilationWidth);

  const opAttrs = [
    createTensorsTypeOpAttr('T', input.dtype),
    {name: 'strides', type: backend.binding.TF_ATTR_INT, value: strides},
    {name: 'padding', type: backend.binding.TF_ATTR_STRING, value: padding},
    {
      name: 'data_format',
      type: backend.binding.TF_ATTR_STRING,
      value: dataFormat
    },
    {name: 'dilations', type: backend.binding.TF_ATTR_INT, value: dilations}
  ];

  if (padding === 'EXPLICIT') {
    // One [before, after] pair per dimension, in data_format order; only the
    // two spatial dimensions receive non-zero padding.
    const spatialPads = [
      convInfo.padInfo.top, convInfo.padInfo.bottom, convInfo.padInfo.left,
      convInfo.padInfo.right
    ];
    opAttrs.push({
      name: 'explicit_paddings',
      type: backend.binding.TF_ATTR_INT,
      value: isChannelsLast ? [0, 0, ...spatialPads, 0, 0] :
                              [0, 0, 0, 0, ...spatialPads]
    });
  }

  return backend.executeSingleOutput(
      DepthwiseConv2dNative, opAttrs, [input, filter]) as Tensor4D;
}
no test coverage detected
searching dependent graphs…