MCPcopy
hub / github.com/tensorflow/tfjs / ReduceProgram

Class ReduceProgram

tfjs-backend-webgpu/src/reduce_webgpu.ts:22–145  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

20import {computeDispatch, flatDispatchLayout} from './webgpu_util';
21
22export class ReduceProgram implements WebGPUProgram {
23 outputShape: number[];
24 shaderKey: string;
25 dispatchLayout: {x: number[]};
26 dispatch: [number, number, number];
27 workgroupSize: [number, number, number];
28 variableNames = ['x'];
29 uniforms = 'reduceSize : i32,';
30 reduceType: 'all'|'any'|'max'|'mean'|'min'|'prod'|'sum';
31 inputShape: number[];
32 size = true;
33
/**
 * Configures a reduction along axis 1 of the input viewed as a
 * [batchSize, inSize] matrix.
 *
 * @param reduceInfo Supplies `batchSize` and `inSize` for the 2-D view of x.
 * @param reduceType Which reduction the generated shader performs.
 * @param maxComputeWorkgroupSizeX Upper bound on the workgroup width; a
 *     512-wide workgroup is only selected when this is at least 512.
 */
constructor(
    reduceInfo: backend_util.ReduceInfo,
    reduceType: 'all'|'any'|'max'|'mean'|'min'|'prod'|'sum',
    maxComputeWorkgroupSizeX: number) {
  const {batchSize, inSize} = reduceInfo;
  this.inputShape = [batchSize, inSize];

  // Reduce over axis 1 only; fall back to [1] if the computed output shape
  // comes back empty.
  const reducedShape =
      backend_util.computeOutAndReduceShapes(this.inputShape, [1])[0];
  this.outputShape = reducedShape.length > 0 ? reducedShape : [1];

  // For long rows, global-memory traffic dominates. A wider workgroup lets
  // more threads share the row, so each thread issues fewer reads. The
  // thresholds only ensure the row is long enough to keep the extra threads
  // busy, and 512 is used only when the supplied limit permits it.
  let threadsX = 64;
  if (inSize >= 32768 && maxComputeWorkgroupSizeX >= 512) {
    threadsX = 512;
  } else if (inSize >= 4096) {
    threadsX = 256;
  }
  this.workgroupSize = [threadsX, 1, 1];

  this.dispatchLayout = flatDispatchLayout(this.outputShape);
  // Each workgroup writes exactly one output element, so the dispatch is
  // computed with a per-workgroup tile of [1, 1, 1] rather than the
  // workgroup size.
  this.dispatch =
      computeDispatch(this.dispatchLayout, this.outputShape, [1, 1, 1]);

  this.reduceType = reduceType;
  // The shader key varies only with the reduction type.
  this.shaderKey = `reduce_${reduceType}`;
}
62
63 getUserCode(): string {
64 let reduceOp = ``;
65 let initValue = '0.0';
66 const workgroupSizeX = this.workgroupSize[0];
67 if (this.reduceType === 'min' || this.reduceType === 'max') {
68 reduceOp = `
69 if (isnan(candidate)) {
70 bestValue = uniforms.NAN;
71 } else if (!isnan(bestValue) && candidate ${
72 this.reduceType === 'min' ? '<' : '>'} bestValue)
73 { bestValue = candidate; }`;
74 initValue = 'f32(x[offset])';
75 } else if (this.reduceType === 'sum' || this.reduceType === 'mean') {
76 reduceOp = ' bestValue = bestValue + candidate; ';
77 } else if (this.reduceType === 'prod') {
78 reduceOp = ' bestValue = bestValue * candidate; ';
79 initValue = '1.0';

Callers

nothing calls this directly

Calls

no outgoing calls

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…