| 19 | import {computeDispatch, flatDispatchLayout} from './webgpu_util'; |
| 20 | |
/**
 * Strided-slice program for the WebGPU backend.
 *
 * Each shader invocation handles one flat output index. It converts that index
 * to an output coordinate `coords` and reads the input `x` at
 * `coords * strides + begin`, applied independently on every axis. `begin` and
 * `strides` are uniforms whose WGSL type is derived from the output rank.
 *
 * NOTE(review): the input is read with exactly one coordinate per *output*
 * axis, so `x` is assumed to have the same rank as `destSize`. Presumably
 * callers reshape away shrink/new axes first; confirm at the call site.
 */
export class StridedSliceProgram implements WebGPUProgram {
  variableNames = ['x'];
  uniforms: string;
  outputShape: number[];
  shaderKey: string;
  dispatchLayout: {x: number[]};
  dispatch: [number, number, number];
  // TODO(xing.xu): Increase the workPerThread.
  workPerThread = 1;
  workgroupSize: [number, number, number] = [64, 1, 1];
  // The shader bounds-checks `index` against `uniforms.size`; presumably this
  // flag asks the backend to supply that uniform.
  size = true;

  /**
   * @param destSize Shape of the sliced output; also drives the flat dispatch.
   */
  constructor(destSize: number[]) {
    this.outputShape = destSize;
    this.dispatchLayout = flatDispatchLayout(this.outputShape);
    this.dispatch = computeDispatch(
        this.dispatchLayout, this.outputShape, this.workgroupSize,
        [this.workPerThread, 1, 1]);

    // `begin` and `strides` carry one component per output axis, so they use
    // the same type as the output coordinates. getUserCode uses them as
    // scalars for rank 1 and indexes them per axis otherwise.
    const dtype = getCoordsDataType(this.outputShape.length);
    this.uniforms = `begin : ${dtype}, strides : ${dtype}, `;
    this.shaderKey = 'stridedSlice';
  }

  /**
   * Builds the WGSL body that maps each output element to its input element.
   *
   * @returns Shader source. For rank 1, `coords`, `begin` and `strides` are
   *     used as scalars. For other ranks, axis `i` contributes
   *     `coords[i] * uniforms.strides[i] + uniforms.begin[i]`, and rank 0
   *     yields an empty `getX()` argument list.
   */
  getUserCode(): string {
    const rank = this.outputShape.length;
    const newCoords = rank === 1 ?
        'coords * uniforms.strides + uniforms.begin' :
        Array.from(
             {length: rank},
             (_, i) => `coords[${i}] * uniforms.strides[${
                 i}] + uniforms.begin[${i}]`)
            .join(',');

    return `
      ${main('index')} {
        if (index < uniforms.size) {
          let coords = getCoordsFromIndex(index);
          setOutputAtIndex(index, getX(${newCoords}));
        }
      }
    `;
  }
}
nothing calls this directly
no outgoing calls
no test coverage detected
searching dependent graphs…