(node: Node, tensorMap: NamedTensorsMap, context: ExecutionContext,
ops = tfOps)
| 27 | |
| 28 | export const executeOp: InternalOpExecutor = |
| 29 | (node: Node, tensorMap: NamedTensorsMap, context: ExecutionContext, |
| 30 | ops = tfOps): Tensor[] => { |
| 31 | switch (node.op) { |
| 32 | case 'ConcatV2': |
| 33 | case 'Concat': { |
| 34 | const n = getParamValue('n', node, tensorMap, context) as number; |
| 35 | const axis = |
| 36 | getParamValue('axis', node, tensorMap, context) as number; |
| 37 | let inputs = |
| 38 | getParamValue('tensors', node, tensorMap, context) as Tensor[]; |
| 39 | inputs = inputs.slice(0, n); |
| 40 | return [ops.concat(inputs, axis)]; |
| 41 | } |
| 42 | case 'Gather': { |
| 43 | const input = getParamValue('x', node, tensorMap, context) as Tensor; |
| 44 | const indices = |
| 45 | getParamValue('indices', node, tensorMap, context) as Tensor1D; |
| 46 | return [ops.gather(input, ops.cast(indices, 'int32'), 0)]; |
| 47 | } |
| 48 | case 'GatherV2': { |
| 49 | const axis = |
| 50 | getParamValue('axis', node, tensorMap, context) as number; |
| 51 | const batchDims = |
| 52 | getParamValue('batchDims', node, tensorMap, context) as number; |
| 53 | const input = getParamValue('x', node, tensorMap, context) as Tensor; |
| 54 | const indices = |
| 55 | getParamValue('indices', node, tensorMap, context) as Tensor1D; |
| 56 | return [ops.gather( |
| 57 | input, ops.cast(indices, 'int32'), axis, batchDims)]; |
| 58 | } |
| 59 | case 'Reverse': { |
| 60 | const dims = |
| 61 | getParamValue('dims', node, tensorMap, context) as boolean[]; |
| 62 | const axis = []; |
| 63 | for (let i = 0; i < dims.length; i++) { |
| 64 | if (dims[i]) { |
| 65 | axis.push(i); |
| 66 | } |
| 67 | } |
| 68 | const input = getParamValue('x', node, tensorMap, context) as Tensor; |
| 69 | return [ops.reverse(input, axis)]; |
| 70 | } |
| 71 | case 'ReverseV2': { |
| 72 | const axis = |
| 73 | getParamValue('axis', node, tensorMap, context) as number[]; |
| 74 | const input = getParamValue('x', node, tensorMap, context) as Tensor; |
| 75 | return [ops.reverse(input, axis)]; |
| 76 | } |
| 77 | case 'Slice': { |
| 78 | // tslint:disable-next-line:no-any |
| 79 | const begin = getParamValue('begin', node, tensorMap, context) as any; |
| 80 | // tslint:disable-next-line:no-any |
| 81 | const size = getParamValue('size', node, tensorMap, context) as any; |
| 82 | return [ops.slice( |
| 83 | getParamValue('x', node, tensorMap, context) as Tensor, begin, |
| 84 | size)]; |
| 85 | } |
| 86 | case 'StridedSlice': { |
no test coverage detected
searching dependent graphs…