| 20 | import {computeDispatch, flatDispatchLayout} from './webgpu_util'; |
| 21 | |
| 22 | export class ArgMinMaxProgram implements WebGPUProgram { |
| 23 | outputShape: number[]; |
| 24 | shaderKey: string; |
| 25 | dispatchLayout: {x: number[]}; |
| 26 | dispatch: [number, number, number]; |
| 27 | workgroupSize: [number, number, number] = [64, 1, 1]; |
| 28 | variableNames = ['x']; |
| 29 | uniforms = 'infinityValue : f32,'; |
| 30 | inputShape: number[]; |
| 31 | reductionFactor: number; |
| 32 | op: string; |
| 33 | size = true; |
| 34 | private type: string; |
| 35 | |
/**
 * Builds an arg-min / arg-max program that reduces `inputShape` along a
 * single axis.
 *
 * @param inputShape Shape of the input tensor `x`.
 * @param axis The one axis handed to
 *     `backend_util.computeOutAndReduceShapes` as the reduction axis.
 * @param reduceType `'min'` selects the `<` comparison operator; any other
 *     value selects `>`.
 */
constructor(inputShape: number[], axis: number, reduceType: 'min'|'max') {
  if (reduceType === 'min') {
    this.op = '<';
  } else {
    this.op = '>';
  }

  // `keptShape` is the input shape with the reduced axis removed;
  // `reducedShape` holds the extent of the reduced axis.
  const [keptShape, reducedShape] =
      backend_util.computeOutAndReduceShapes(inputShape, [axis]);

  // An empty (scalar) output shape is represented as [1].
  this.outputShape = keptShape.length === 0 ? [1] : keptShape;
  this.dispatchLayout = flatDispatchLayout(this.outputShape);

  // The shared algorithm is mainly meant for large reduce sizes: it uses all
  // the threads of one workgroup to perform the reduction. When the reduce
  // size is very small, the plain algorithm is preferred because it needs
  // fewer workgroups. The threshold (32) can be further tuned.
  const isSmallReduction = util.sizeFromShape(reducedShape) < 32;
  this.type = isSmallReduction ? 'plain' : 'shared';

  // In the shared algorithm a workgroup produces exactly one output element,
  // so [1, 1, 1] is passed as the workgroup size when computing the dispatch
  // size; the plain algorithm uses the program's real workgroup size.
  const dispatchWorkgroupSize: [number, number, number] =
      isSmallReduction ? this.workgroupSize : [1, 1, 1];
  this.dispatch = computeDispatch(
      this.dispatchLayout, this.outputShape, dispatchWorkgroupSize);

  this.inputShape = inputShape;
  this.shaderKey = `argMinMax_${this.op}_${this.type}`;
}
| 67 | |
| 68 | getUserCode(): string { |
| 69 | const workgroupSizeX = this.workgroupSize[0]; |
| 70 | const getInputShapeLastDim = () => { |
| 71 | if (this.inputShape.length === 1) { |
| 72 | return 'uniforms.xShape'; |
| 73 | } else { |
| 74 | return `uniforms.xShape.${getCoordsXYZ(this.inputShape.length - 1)}`; |
| 75 | } |
| 76 | }; |
| 77 | |
| 78 | const splitOutputCoords = () => { |
| 79 | let snippet = ''; |
nothing calls this directly
no outgoing calls
no test coverage detected
searching dependent graphs…