(args: {
inputs: StridedSliceInputs,
backend: MathBackendCPU,
attrs: StridedSliceAttrs
})
| 24 | import {stridedSliceImpl} from './StridedSlice_impl'; |
| 25 | |
| 26 | export function stridedSlice(args: { |
| 27 | inputs: StridedSliceInputs, |
| 28 | backend: MathBackendCPU, |
| 29 | attrs: StridedSliceAttrs |
| 30 | }): TensorInfo { |
| 31 | const {inputs, backend, attrs} = args; |
| 32 | const {x} = inputs; |
| 33 | const { |
| 34 | begin, |
| 35 | end, |
| 36 | strides, |
| 37 | beginMask, |
| 38 | endMask, |
| 39 | ellipsisMask, |
| 40 | newAxisMask, |
| 41 | shrinkAxisMask |
| 42 | } = attrs; |
| 43 | |
| 44 | assertNotComplex(x, 'stridedSlice'); |
| 45 | |
| 46 | const { |
| 47 | finalShapeSparse, |
| 48 | finalShape, |
| 49 | isIdentity, |
| 50 | sliceDim0, |
| 51 | isSimpleSlice, |
| 52 | begin: $begin, |
| 53 | end: $end, |
| 54 | strides: $strides |
| 55 | } = |
| 56 | slice_util.sliceInfo( |
| 57 | x.shape, begin, end, strides, beginMask, endMask, ellipsisMask, |
| 58 | newAxisMask, shrinkAxisMask); |
| 59 | |
| 60 | let result; |
| 61 | |
| 62 | // ref: |
| 63 | // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/strided_slice_op.cc |
| 64 | if (isIdentity) { |
| 65 | // Optimization #1, slice is a no-op plus reshape |
| 66 | result = reshape({inputs: {x}, backend, attrs: {shape: finalShape}}); |
| 67 | } else if (sliceDim0 || isSimpleSlice) { |
| 68 | // Optimization #2, slice is memory contiguous (only occurs in dim 0) |
| 69 | util.assert( |
| 70 | x.shape.length >= 1, |
| 71 | () => `Input must have rank at least 1, got: ${x.shape.length}`); |
| 72 | |
| 73 | const size = slice_util.computeOutShape($begin, $end, $strides); |
| 74 | // To tolerate begin[0] > end[0] (a 0-output slice), we min(begin, end). |
| 75 | const sliced = slice({inputs: {x}, backend, attrs: {begin: $begin, size}}); |
| 76 | result = |
| 77 | reshape({inputs: {x: sliced}, backend, attrs: {shape: finalShape}}); |
| 78 | backend.disposeIntermediateTensorInfo(sliced); |
| 79 | } else { |
| 80 | const xBuf = backend.bufferSync<Rank, 'float32'>(x); |
| 81 | const outBuf = stridedSliceImpl(finalShapeSparse, xBuf, $strides, $begin); |
| 82 | |
| 83 | result = backend.makeTensorInfo(finalShape, outBuf.dtype, outBuf.values); |
no test coverage detected
searching dependent graphs…