| 20 | import {computeDispatch, flatDispatchLayout} from './webgpu_util'; |
| 21 | |
| 22 | export class ReduceProgram implements WebGPUProgram { |
/** Shape of the reduced output; the constructor substitutes [1] when the computed shape is empty. */
outputShape: number[];
/** Shader variant key; the constructor sets it to `reduce_<reduceType>`. */
shaderKey: string;
/** Dispatch layout produced by `flatDispatchLayout` over the output shape. */
dispatchLayout: {x: number[]};
/** Dispatch size produced by `computeDispatch` (one workgroup per output element). */
dispatch: [number, number, number];
/** Threads per workgroup; only x varies (64, 256 or 512, chosen in the constructor). */
workgroupSize: [number, number, number];
/** Input binding names; the generated shader reads from `x`. */
variableNames = ['x'];
/**
 * Program-specific uniform field declaration(s) in WGSL syntax.
 * NOTE(review): presumably merged into the backend's uniform struct — confirm.
 */
uniforms = 'reduceSize : i32,';
/** Reduction operator; selects the body of the generated shader. */
reduceType: 'all'|'any'|'max'|'mean'|'min'|'prod'|'sum';
/** 2-D view of the input: [batchSize, inSize]. */
inputShape: number[];
/** NOTE(review): flag read by the WebGPU backend (likely requests a size uniform) — confirm. */
size = true;
| 33 | |
/**
 * Builds a reduction program over a 2-D [batchSize, inSize] view of the
 * input, reducing along axis 1.
 *
 * @param reduceInfo Batch and reduce sizes describing the input view.
 * @param reduceType Reduction operator to generate a shader for.
 * @param maxComputeWorkgroupSizeX Upper bound on threads per workgroup along
 *     x; the 512-thread variant is only chosen when this allows it.
 */
constructor(
    reduceInfo: backend_util.ReduceInfo,
    reduceType: 'all'|'any'|'max'|'mean'|'min'|'prod'|'sum',
    maxComputeWorkgroupSizeX: number) {
  const {batchSize, inSize} = reduceInfo;
  this.inputShape = [batchSize, inSize];

  // Only the first shape returned by computeOutAndReduceShapes (the output
  // shape) is used; an empty result falls back to a one-element output.
  const keptShape =
      backend_util.computeOutAndReduceShapes(this.inputShape, [1])[0];
  this.outputShape = keptShape.length > 0 ? keptShape : [1];

  // For very long rows, global-memory access dominates; more threads per
  // workgroup reduce how often each thread touches global memory. The
  // thresholds ensure the row is long enough to keep a larger group busy.
  let threadsX = 64;
  if (inSize >= 32768 && maxComputeWorkgroupSizeX >= 512) {
    threadsX = 512;
  } else if (inSize >= 4096) {
    threadsX = 256;
  }
  this.workgroupSize = [threadsX, 1, 1];

  // A workgroup writes a single output element, so [1, 1, 1] is passed when
  // computing the dispatch size over the output shape.
  this.dispatchLayout = flatDispatchLayout(this.outputShape);
  this.dispatch =
      computeDispatch(this.dispatchLayout, this.outputShape, [1, 1, 1]);

  // The generated shader body depends on the operator, so it is part of
  // the shader key.
  this.reduceType = reduceType;
  this.shaderKey = 'reduce_' + reduceType;
}
| 62 | |
| 63 | getUserCode(): string { |
| 64 | let reduceOp = ``; |
| 65 | let initValue = '0.0'; |
| 66 | const workgroupSizeX = this.workgroupSize[0]; |
| 67 | if (this.reduceType === 'min' || this.reduceType === 'max') { |
| 68 | reduceOp = ` |
| 69 | if (isnan(candidate)) { |
| 70 | bestValue = uniforms.NAN; |
| 71 | } else if (!isnan(bestValue) && candidate ${ |
| 72 | this.reduceType === 'min' ? '<' : '>'} bestValue) |
| 73 | { bestValue = candidate; }`; |
| 74 | initValue = 'f32(x[offset])'; |
| 75 | } else if (this.reduceType === 'sum' || this.reduceType === 'mean') { |
| 76 | reduceOp = ' bestValue = bestValue + candidate; '; |
| 77 | } else if (this.reduceType === 'prod') { |
| 78 | reduceOp = ' bestValue = bestValue * candidate; '; |
| 79 | initValue = '1.0'; |
nothing calls this directly
no outgoing calls
no test coverage detected
searching dependent graphs…