MCPcopy
hub / github.com/tensorflow/tfjs / ArgMinMaxProgram

Class ArgMinMaxProgram

tfjs-backend-webgpu/src/argminmax_webgpu.ts:22–167  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

20import {computeDispatch, flatDispatchLayout} from './webgpu_util';
21
22export class ArgMinMaxProgram implements WebGPUProgram {
23 outputShape: number[];
24 shaderKey: string;
25 dispatchLayout: {x: number[]};
26 dispatch: [number, number, number];
27 workgroupSize: [number, number, number] = [64, 1, 1];
28 variableNames = ['x'];
29 uniforms = 'infinityValue : f32,';
30 inputShape: number[];
31 reductionFactor: number;
32 op: string;
33 size = true;
34 private type: string;
35
constructor(inputShape: number[], axis: number, reduceType: 'min'|'max') {
  // Comparison operator spliced into the generated shader: '<' selects the
  // arg-min, '>' selects the arg-max.
  this.op = reduceType === 'min' ? '<' : '>';

  // `outShape` is `inputShape` with `axis` removed. Only the element count of
  // `reduceShape` is used below, to pick a shader variant.
  const [outShape, reduceShape] =
      backend_util.computeOutAndReduceShapes(inputShape, [axis]);

  // A fully reduced (rank-0) result is still given the shape [1].
  this.outputShape = outShape.length > 0 ? outShape : [1];
  this.dispatchLayout = flatDispatchLayout(this.outputShape);

  // Two shader variants exist. 'shared' has all threads of a workgroup
  // cooperate on the reduction, which suits large reduce sizes. When the
  // reduce size is small, 'plain' is preferred because it launches fewer
  // workgroups. The threshold of 32 is a tunable heuristic.
  const usePlain = util.sizeFromShape(reduceShape) < 32;
  this.type = usePlain ? 'plain' : 'shared';

  // In the 'shared' variant each workgroup produces exactly one output
  // value, so the dispatch is sized as if the workgroup were [1, 1, 1].
  const dispatchGroup: [number, number, number] =
      usePlain ? this.workgroupSize : [1, 1, 1];
  this.dispatch = computeDispatch(
      this.dispatchLayout, this.outputShape, dispatchGroup);

  this.inputShape = inputShape;
  this.shaderKey = `argMinMax_${this.op}_${this.type}`;
}
67
68 getUserCode(): string {
69 const workgroupSizeX = this.workgroupSize[0];
70 const getInputShapeLastDim = () => {
71 if (this.inputShape.length === 1) {
72 return 'uniforms.xShape';
73 } else {
74 return `uniforms.xShape.${getCoordsXYZ(this.inputShape.length - 1)}`;
75 }
76 };
77
78 const splitOutputCoords = () => {
79 let snippet = '';

Callers

nothing calls this directly

Calls

no outgoing calls

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…