MCPcopy
hub / github.com/tensorflow/tfjs / GatherProgram

Class GatherProgram

tfjs-backend-webgpu/src/gather_webgpu.ts:21–55  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

19import {computeDispatch, flatDispatchLayout} from './webgpu_util';
20
/**
 * WebGPU program that gathers elements of input `A` using the values stored
 * in `indices`.
 *
 * The program is dispatched over the flattened output (one invocation per
 * output element). Each invocation:
 *  - converts its flat index into output coordinates `resRC`;
 *  - reads a gather index from `indices` at `(resRC.x, resRC.z)`;
 *  - writes `A` at the source coordinates built by `getSourceCoords`, or 0 when
 *    the gather index is outside `[0, aShape[2])`.
 *
 * NOTE(review): how `indexZ` is placed into the source coordinates is decided
 * by `getSourceCoords`, which is not shown here. Presumably it replaces axis 2
 * (consistent with the `aShape[2]` bounds check), but this should be confirmed.
 */
export class GatherProgram implements WebGPUProgram {
  outputShape: number[];
  shaderKey: string;
  dispatchLayout: {x: number[]};
  dispatch: [number, number, number];
  variableNames: string[] = ['A', 'indices'];
  workgroupSize: [number, number, number] = [64, 1, 1];
  aShape: number[];
  // Enables the `uniforms.size` bound used in the shader below.
  size = true;

  /**
   * @param aShape Shape of the input tensor `A`; used to build the source
   *     coordinate expression in the shader.
   * @param outputShape Shape of the gathered output; it drives the dispatch size.
   */
  constructor(aShape: number[], outputShape: number[]) {
    this.aShape = aShape;
    this.outputShape = outputShape;
    this.dispatchLayout = flatDispatchLayout(this.outputShape);
    this.dispatch = computeDispatch(
        this.dispatchLayout, this.outputShape, this.workgroupSize);
    // NOTE(review): the key is constant even though getUserCode interpolates
    // coordinates derived from aShape. This is only safe because the inputs
    // and outputs are always flattened to rank-4 tensors. Confirm before
    // allowing other ranks.
    this.shaderKey = `gather`;
  }

  /** Returns the WGSL body for the gather kernel. */
  getUserCode(): string {
    const sourceCoords = getSourceCoords(this.aShape);
    const userCode = `
      ${main('index')} {
        if (index < uniforms.size) {
          let resRC = getCoordsFromIndex(index);
          let indexZ = i32(getIndices(resRC.x, resRC.z));
          let inBounds = select(0.0, 1.0, indexZ >= 0 && indexZ < uniforms.aShape[2]);
          setOutputAtIndex(index, inBounds * getA(${sourceCoords}));
        }
      }
    `;
    return userCode;
  }
}
56
57// The input and output are always flattened into rank 4 tensors.
58function getSourceCoords(aShape: number[]): string {

Callers

nothing calls this directly

Calls

no outgoing calls

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…