(
x: TensorInfo, dtype: DataType, reductionType: ReduceTypes,
backend: MathBackendWebGL)
| 44 | } |
| 45 | |
/**
 * Reduces the inner dimension of `x` through a sequence of WebGL program
 * passes, one pass per stage returned by `getReductionStages(x.shape)`.
 *
 * Each stage consumes the previous stage's output. For `'mean'` a
 * `MeanProgram` is built, and only the first stage is given `inSize` as an
 * extra constructor argument. Every other reduction type uses a
 * `ReduceProgram` parameterized by `reductionType`.
 *
 * Intermediate outputs are released via `backend.disposeIntermediateTensorInfo`
 * once the next stage has consumed them. The caller's input `x` is never
 * disposed here.
 *
 * NOTE(review): assumes `x` is 2-D `[batchSize, inSize]` — only `x.shape[0]`
 * is read here, as the batch size. Confirm against callers.
 *
 * @param x Tensor to reduce.
 * @param dtype Output dtype passed to every `runWebGLProgram` call.
 * @param reductionType Which reduction to perform (`'mean'` selects
 *     `MeanProgram`; anything else is forwarded to `ReduceProgram`).
 * @param backend WebGL backend used to run programs and free intermediates.
 * @returns Output of the final stage, or `x` itself if there are no stages.
 */
export function reduce(
    x: TensorInfo, dtype: DataType, reductionType: ReduceTypes,
    backend: MathBackendWebGL): TensorInfo {
  const reductionStages = getReductionStages(x.shape);
  // The batch dimension is the same for every stage, so read it once.
  const batchSize = x.shape[0];

  let result = x;
  for (let i = 0; i < reductionStages.length; i++) {
    const {inSize, windowSize, outSize} = reductionStages[i];
    const reduceInfo = {windowSize, inSize, batchSize, outSize};

    let program: ReduceProgram|MeanProgram;
    if (reductionType === 'mean') {
      // Only the first stage receives `inSize` as an extra argument —
      // presumably a divisor applied once, so later stages only sum partial
      // results. NOTE(review): confirm against MeanProgram.
      program = i === 0 ? new MeanProgram(reduceInfo, inSize) :
                          new MeanProgram(reduceInfo);
    } else {
      program = new ReduceProgram(reduceInfo, reductionType);
    }

    const previousResult = result;
    result = backend.runWebGLProgram(program, [result], dtype);

    // Free the previous stage's output, but never the caller-owned input.
    if (previousResult.dataId !== x.dataId) {
      backend.disposeIntermediateTensorInfo(previousResult);
    }
  }

  return result;
}
no test coverage detected
searching dependent graphs…