| 19 | import {computeDispatch, flatDispatchLayout} from './webgpu_util'; |
| 20 | |
/**
 * WebGPU program that gathers elements of input `A` using the values stored
 * in the `indices` input.
 *
 * The output has the shape given to the constructor and is dispatched as a
 * flat layout (`flatDispatchLayout`) with 64-invocation workgroups.
 */
export class GatherProgram implements WebGPUProgram {
  outputShape: number[];
  shaderKey: string;
  dispatchLayout: {x: number[]};
  dispatch: [number, number, number];
  variableNames: string[] = ['A', 'indices'];
  workgroupSize: [number, number, number] = [64, 1, 1];
  /** Shape of input `A`; passed to `getSourceCoords` when building WGSL. */
  aShape: number[];
  // Presumably requests that `uniforms.size` be provided to the shader (it is
  // read in getUserCode) — TODO confirm against the WebGPUProgram contract.
  size = true;

  /**
   * @param aShape Shape of input `A`. Stored by reference, not copied.
   * @param outputShape Shape of the gathered result; drives the dispatch size.
   */
  constructor(aShape: number[], outputShape: number[]) {
    this.aShape = aShape;
    this.outputShape = outputShape;
    this.dispatchLayout = flatDispatchLayout(this.outputShape);
    this.dispatch = computeDispatch(
        this.dispatchLayout, this.outputShape, this.workgroupSize);
    // NOTE(review): the key does not encode aShape even though getUserCode()
    // feeds it to getSourceCoords(). This is only safe if the generated source
    // is shape-independent (this file notes inputs are always flattened to
    // rank 4) — confirm before adding other ranks.
    this.shaderKey = `gather`;
  }

  /**
   * Builds the WGSL entry point.
   *
   * For every flat output index below `uniforms.size`, the shader:
   * - reads the gather index from `indices` at `(resRC.x, resRC.z)`;
   * - reads `A` at the coordinates produced by `getSourceCoords`;
   * - multiplies that value by 1.0 when the index lies in
   *   `[0, uniforms.aShape[2])`, and by 0.0 otherwise.
   *
   * Out-of-range indices therefore yield 0, assuming `getA` returns a finite
   * value for such coordinates — TODO confirm.
   */
  getUserCode(): string {
    const sourceCoords = getSourceCoords(this.aShape);
    const userCode = `
      ${main('index')} {
        if (index < uniforms.size) {
          let resRC = getCoordsFromIndex(index);
          let indexZ = i32(getIndices(resRC.x, resRC.z));
          let inBounds = select(0.0, 1.0, indexZ >= 0 && indexZ < uniforms.aShape[2]);
          setOutputAtIndex(index, inBounds * getA(${sourceCoords}));
        }
      }
    `;
    return userCode;
  }
}
| 56 | |
| 57 | // The input and output are always flattened into rank 4 tensors. |
| 58 | function getSourceCoords(aShape: number[]): string { |
nothing calls this directly
no outgoing calls
no test coverage detected
searching dependent graphs…