(op: SimpleBinaryOperation)
| 23 | * Template that creates implementation for binary ops. Supports broadcast. |
| 24 | */ |
export function createSimpleBinaryKernelImpl(op: SimpleBinaryOperation):
    SimpleBinaryKernelImpl {
  // The returned kernel applies `op` element by element over the broadcast of
  // `aShape` and `bShape`, and yields the flat output values together with the
  // broadcast output shape.
  return (aShape: number[], bShape: number[], aVals: DataValues,
          bVals: DataValues, dtype: DataType): [TypedArray, number[]] => {
    // Validates that the two shapes are broadcast-compatible and yields the
    // output shape.
    const outShape = backend_util.assertAndGetBroadcastShape(aShape, bShape);

    const outRank = outShape.length;
    const outStrides = util.computeStrides(outShape);
    const outSize = util.sizeFromShape(outShape);

    const outVals =
        util.getTypedArrayFromDType(dtype as NumericDataType, outSize);

    const rankOfA = aShape.length;
    const rankOfB = bShape.length;

    const stridesOfA = util.computeStrides(aShape);
    const stridesOfB = util.computeStrides(bShape);

    // Dimensions of each input that have to be pinned to index 0 while walking
    // the output.
    const pinnedDimsOfA = backend_util.getBroadcastDims(aShape, outShape);
    const pinnedDimsOfB = backend_util.getBroadcastDims(bShape, outShape);

    const total = outVals.length;

    // Fast path: neither input reports a broadcast dimension, so flat indices
    // can be used directly. The modulo wraps around an input whose buffer is
    // shorter than the output.
    if (pinnedDimsOfA.length + pinnedDimsOfB.length === 0) {
      for (let flat = 0; flat < total; ++flat) {
        outVals[flat] =
            op(aVals[flat % aVals.length], bVals[flat % bVals.length]);
      }
      return [outVals, outShape];
    }

    // Maps an output coordinate to the flat index of the matching input
    // element: keep the trailing `rank` coordinates, zero the pinned ones, and
    // convert back to a flat offset.
    // NOTE(review): for rank 0, `slice(-0)` keeps the whole coordinate;
    // presumably `locToIndex` ignores it in that case — confirm.
    const inputIndexFor =
        (outLoc: number[], rank: number, strides: number[],
         pinnedDims: number[]): number => {
          const inLoc = outLoc.slice(-rank);
          for (const dim of pinnedDims) {
            inLoc[dim] = 0;
          }
          return util.locToIndex(inLoc, rank, strides);
        };

    // Slow path: translate every output position into per-input positions.
    for (let flat = 0; flat < total; ++flat) {
      const outLoc = util.indexToLoc(flat, outRank, outStrides);

      const indexInA =
          inputIndexFor(outLoc, rankOfA, stridesOfA, pinnedDimsOfA);
      const indexInB =
          inputIndexFor(outLoc, rankOfB, stridesOfB, pinnedDimsOfB);

      outVals[flat] = op(aVals[indexInA], bVals[indexInB]);
    }

    return [outVals, outShape];
  };
}
no test coverage detected
searching dependent graphs…