MCP · copy · Index your code
hub / github.com/tensorflow/tfjs / reduce

Function reduce

tfjs-backend-webgpu/src/kernel_utils/reduce.ts:35–103  ·  view source on GitHub ↗
(
    x: TensorInfo, axis: number|number[], keepDims: boolean,
    reduceType: ReduceTypes, backend: WebGPUBackend)

Source from the content-addressed store, hash-verified

33};
34
35export function reduce(
36 x: TensorInfo, axis: number|number[], keepDims: boolean,
37 reduceType: ReduceTypes, backend: WebGPUBackend): TensorInfo {
38 const xRank = x.shape.length;
39 const toDispose = [];
40
41 const origAxes = util.parseAxisParam(axis, x.shape);
42 let axes = origAxes;
43 const permutedAxes = backend_util.getAxesPermutation(axes, xRank);
44
45 let input = x;
46 if (permutedAxes != null) {
47 input = transpose({inputs: {x}, attrs: {perm: permutedAxes}, backend});
48 axes = backend_util.getInnerMostAxes(axes.length, xRank);
49 toDispose.push(input);
50 }
51
52 backend_util.assertAxesAreInnerMostDims(reduceType, axes, xRank);
53
54 const [reduceOutShape, reduceShape] =
55 backend_util.computeOutAndReduceShapes(input.shape, axes);
56 let resOutShape = reduceOutShape;
57 if (keepDims) {
58 // rather than reshape at the end, set the target shape here.
59 resOutShape = backend_util.expandShapeToKeepDim(reduceOutShape, origAxes);
60 }
61
62 let res;
63 if ((reduceType === 'max' || reduceType === 'prod') &&
64 backend.shouldExecuteOnCPU([input])) {
65 const xVals = backend.tensorMap.get(input.dataId).values as TypedArray;
66 switch (reduceType) {
67 case 'max':
68 const outValues = maxImplCPU(
69 xVals, util.sizeFromShape(reduceShape), resOutShape, x.dtype);
70 res = backend.makeTensorInfo(resOutShape, x.dtype, outValues);
71 break;
72 case 'prod':
73 const {outVals, outShape, outDtype} =
74 prodImplCPU(input.shape, input.dtype, xVals, axes);
75 res = backend.makeTensorInfo(outShape, outDtype, outVals);
76 break;
77 default:
78 throw new Error(
79 `${reduceType} CPU implementation is not yet supported.`);
80 }
81 } else {
82 const inSize = util.sizeFromShape(reduceShape);
83 const xSize = util.sizeFromShape(input.shape);
84 const batchSize = xSize / inSize;
85
86 const reduceInfo = {windowSize: inSize, inSize, batchSize, outSize: 1};
87 const dtype = RETURN_TYPES[reduceType] || sumOutType(x.dtype);
88 const uniformData = [
89 {type: 'int32', data: [inSize]},
90 ];
91 const program = new ReduceProgram(
92 reduceInfo, reduceType, backend.device.limits.maxComputeWorkgroupSizeX);

Callers 7

all (function) · 0.90
any (function) · 0.90
min (function) · 0.90
mean (function) · 0.90
prod (function) · 0.90
max (function) · 0.90
sum (function) · 0.90

Calls 9

transpose (function) · 0.90
sumOutType (function) · 0.90
reshape (function) · 0.90
runWebGPUProgram (method) · 0.80
disposeData (method) · 0.65
push (method) · 0.45
shouldExecuteOnCPU (method) · 0.45
get (method) · 0.45
makeTensorInfo (method) · 0.45

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…