(
x: TensorInfo, filter: TensorInfo, convInfo: backend_util.Conv2DInfo,
backend: NodeJSKernelBackend)
| 39 | }; |
| 40 | |
/**
 * Executes a 2D convolution through the TensorFlow binding's `Conv2D` op.
 *
 * Op attributes (strides, dilations, padding, data format and, for explicit
 * padding, `explicit_paddings`) are built from `convInfo`. TensorFlow reads
 * each per-dimension attribute in the order named by `data_format`, so the
 * spatial values are placed in the H/W slots of whichever layout is used.
 *
 * @param x Input tensor; passed as the op's first input.
 * @param filter Convolution filter; passed as the op's second input.
 * @param convInfo Precomputed convolution parameters.
 * @param backend Backend whose `binding` supplies the attribute type constants
 *     and which executes the op.
 * @returns The single output of the `Conv2D` op, cast to `Tensor4D`.
 * @throws Error when `convInfo.padInfo.type` is not 'VALID', 'SAME' or
 *     'EXPLICIT'.
 */
export function conv2dImpl(
    x: TensorInfo, filter: TensorInfo, convInfo: backend_util.Conv2DInfo,
    backend: NodeJSKernelBackend): Tensor4D {
  const padding = convInfo.padInfo.type;
  if (padding !== 'VALID' && padding !== 'SAME' && padding !== 'EXPLICIT') {
    throw new Error(
        `TF Backend supports only 'valid', 'same' and 'explicit' padding ` +
        `while padding was ${padding}`);
  }

  const isChannelsLast = convInfo.dataFormat === 'channelsLast';
  const dataFormat = isChannelsLast ? 'NHWC' : 'NCHW';
  // Places a (height, width) pair into a 4-element per-dimension attribute in
  // `data_format` order, leaving the batch and channel entries at 1.
  const spatialAttr = (h: number, w: number): number[] =>
      isChannelsLast ? [1, h, w, 1] : [1, 1, h, w];
  const strides = spatialAttr(convInfo.strideHeight, convInfo.strideWidth);
  const dilations =
      spatialAttr(convInfo.dilationHeight, convInfo.dilationWidth);

  const opAttrs = [
    createTensorsTypeOpAttr('T', x.dtype),
    {name: 'strides', type: backend.binding.TF_ATTR_INT, value: strides},
    {name: 'padding', type: backend.binding.TF_ATTR_STRING, value: padding},
    {
      name: 'data_format',
      type: backend.binding.TF_ATTR_STRING,
      value: dataFormat
    },
    {name: 'use_cudnn_on_gpu', type: backend.binding.TF_ATTR_BOOL, value: true},
    {name: 'dilations', type: backend.binding.TF_ATTR_INT, value: dilations},
  ];
  if (padding === 'EXPLICIT') {
    // One (before, after) pair per dimension, in `data_format` order; only the
    // spatial dimensions receive padding.
    const {top, bottom, left, right} = convInfo.padInfo;
    const spatialPads = [top, bottom, left, right];
    opAttrs.push({
      name: 'explicit_paddings',
      type: backend.binding.TF_ATTR_INT,
      value: isChannelsLast ? [0, 0, ...spatialPads, 0, 0] :
                              [0, 0, 0, 0, ...spatialPads]
    });
  }
  return backend.executeSingleOutput(Conv2D, opAttrs, [x, filter]) as Tensor4D;
}
no test coverage detected
searching dependent graphs…