MCPcopy Index your code
hub / github.com/tensorflow/tfjs / StridedSliceProgram

Class StridedSliceProgram

tfjs-backend-webgpu/src/strided_slice_webgpu.ts:21–74  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

19import {computeDispatch, flatDispatchLayout} from './webgpu_util';
20
/**
 * WebGPU program that reads a strided slice of the input tensor `x`.
 *
 * Each shader invocation handles one flat output index. It converts that
 * index to output coordinates and reads `x` at
 * `coords[i] * strides[i] + begin[i]` on every axis. `begin` and `strides`
 * are supplied as uniforms whose WGSL type is chosen from the output rank.
 *
 * NOTE(review): because `begin`/`strides` are typed with the output rank, the
 * input `x` is assumed to have the same rank as `outputShape`. Confirm this
 * against the calling kernel.
 */
export class StridedSliceProgram implements WebGPUProgram {
  variableNames = ['x'];
  uniforms: string;
  outputShape: number[];
  shaderKey: string;
  dispatchLayout: {x: number[]};
  dispatch: [number, number, number];
  // TODO(xing.xu): Increase the workPerThread.
  workPerThread = 1;
  workgroupSize: [number, number, number] = [64, 1, 1];
  // Presumably asks the backend to supply `uniforms.size`. The shader uses it
  // to skip invocations whose flat index is past the end of the output.
  size = true;

  /**
   * @param destSize Shape of the output tensor; it also determines the
   *     dispatch size and the WGSL type of the `begin`/`strides` uniforms.
   */
  constructor(destSize: number[]) {
    this.outputShape = destSize;
    this.dispatchLayout = flatDispatchLayout(this.outputShape);
    this.dispatch = computeDispatch(
        this.dispatchLayout, this.outputShape, this.workgroupSize,
        [this.workPerThread, 1, 1]);

    const dtype = getCoordsDataType(this.outputShape.length);
    this.uniforms = `begin : ${dtype}, strides : ${dtype}, `;
    this.shaderKey = 'stridedSlice';
  }

  /**
   * Builds the WGSL body.
   *
   * For every in-bounds flat output index, the body maps the output
   * coordinates into `x` and copies that element to the output.
   *
   * @returns WGSL source for the program's main function.
   */
  getUserCode(): string {
    const rank = this.outputShape.length;
    // Rank 1 uses `coords`, `strides` and `begin` as scalars. Higher ranks
    // index one component per axis. The axis index is simply the map index;
    // no separate counter is needed.
    const newCoords = rank === 1 ?
        'coords * uniforms.strides + uniforms.begin' :
        this.outputShape
            .map(
                (_, axis) => `coords[${axis}] * uniforms.strides[${
                    axis}] + uniforms.begin[${axis}]`)
            .join(',');

    const userCode = `
      ${main('index')} {
        if (index < uniforms.size) {
          let coords = getCoordsFromIndex(index);
          setOutputAtIndex(index, getX(${newCoords}));
        }
      }
    `;
    return userCode;
  }
}

Callers

nothing calls this directly

Calls

no outgoing calls

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…