(args: {inputs: SelectInputs, backend: MathBackendCPU})
| 21 | import {assertNotComplex} from '../cpu_util'; |
| 22 | |
/**
 * CPU implementation of the `Select` kernel: builds a tensor shaped like `t`
 * whose elements are taken from `t` where `condition` is 1 and from `e`
 * otherwise.
 *
 * Two layouts of `condition` are handled:
 *  - `condition` of rank 0, rank > 1, or `t` of rank 1: one condition entry
 *    decides exactly one output element.
 *  - `condition` of rank 1 with `t` of higher rank: one condition entry
 *    decides a whole slice along the first axis of `t`, i.e.
 *    `sizeFromShape(t.shape.slice(1))` consecutive elements.
 *
 * NOTE(review): `e` is assumed to have the same shape as `t`, and `condition`
 * is assumed to already match `t` (same shape, or length `t.shape[0]` when
 * rank 1); none of this is validated here — confirm the op layer enforces it.
 *
 * @param args.inputs The `condition`, `t` and `e` tensors.
 * @param args.backend CPU backend used to read input data and allocate the
 *     result.
 * @returns A tensor of shape `t.shape` and dtype
 *     `upcastType(t.dtype, e.dtype)`.
 */
export function select(args: {inputs: SelectInputs, backend: MathBackendCPU}):
    TensorInfo {
  const {inputs, backend} = args;
  const {condition, t, e} = inputs;

  assertNotComplex([condition, t, e], 'select');
  const conditionRank = condition.shape.length;

  const values = backend.data.get(condition.dataId).values as TypedArray;
  const tValues = backend.data.get(t.dataId).values as TypedArray;
  const eValues = backend.data.get(e.dataId).values as TypedArray;
  const resultDtype = upcastType(t.dtype, e.dtype);
  const newValues =
      util.makeZerosTypedArray(util.sizeFromShape(t.shape), resultDtype);

  // Number of consecutive output elements governed by a single condition
  // entry: 1 for element-wise selection, a full first-axis slice of `t` when
  // a rank-1 condition picks rows of a higher-rank tensor.
  const offset =
      conditionRank === 0 || conditionRank > 1 || t.shape.length === 1 ?
      1 :
      util.sizeFromShape(t.shape.slice(1));

  let index = 0;
  for (let i = 0; i < values.length; i++) {
    // Only a condition value strictly equal to 1 selects from `t`.
    const source = values[i] === 1 ? tValues : eValues;
    for (let j = 0; j < offset; j++) {
      // Read at the output position, not at the condition position `i`:
      // with offset > 1 the two differ, and indexing by `i` would smear a
      // single unrelated element across the whole slice.
      newValues[index] = source[index];
      index++;
    }
  }

  return backend.makeTensorInfo(t.shape, resultDtype, newValues);
}
| 56 | |
| 57 | export const selectConfig: KernelConfig = { |
| 58 | kernelName: Select, |
nothing calls this directly
no test coverage detected
searching dependent graphs…