MCP · copy · Index your code
hub / github.com/tensorflow/tfjs / batchMatMulImpl

Function batchMatMulImpl

tfjs-backend-webgpu/src/kernels/BatchMatMul_impl.ts:43–218  ·  view source on GitHub ↗
({
  a,
  b,
  transposeA,
  transposeB,
  backend,
  bias = null,
  preluActivationWeights = null,
  leakyreluAlpha = 0,
  activation = null
}: BatchMatMulConfig)

Source from the content-addressed store, hash-verified

41};
42
43export function batchMatMulImpl({
44 a,
45 b,
46 transposeA,
47 transposeB,
48 backend,
49 bias = null,
50 preluActivationWeights = null,
51 leakyreluAlpha = 0,
52 activation = null
53}: BatchMatMulConfig): TensorInfo {
54 const aRank = a.shape.length;
55 const bRank = b.shape.length;
56
57 const innerShapeA = transposeA ? a.shape[aRank - 2] : a.shape[aRank - 1];
58 const innerShapeB = transposeB ? b.shape[bRank - 1] : b.shape[bRank - 2];
59
60 const outerShapeA = transposeA ? a.shape[aRank - 1] : a.shape[aRank - 2];
61 const outerShapeB = transposeB ? b.shape[bRank - 2] : b.shape[bRank - 1];
62
63 const outerDimsA = a.shape.slice(0, -2);
64 const outerDimsB = b.shape.slice(0, -2);
65
66 const batchDimA = util.sizeFromShape(outerDimsA);
67 const batchDimB = util.sizeFromShape(outerDimsB);
68
69 const outShapeOuterDims = broadcast_util.assertAndGetBroadcastShape(
70 a.shape.slice(0, -2), b.shape.slice(0, -2));
71 const outShape = outShapeOuterDims.concat([outerShapeA, outerShapeB]);
72
73 util.assert(
74 innerShapeA === innerShapeB,
75 () => `Error in matMul: inner shapes (${innerShapeA}) and (` +
76 `${innerShapeB}) of Tensors with shapes ${a.shape} and ` +
77 `${b.shape} and transposeA=${transposeA}` +
78 ` and transposeB=${transposeB} must match.`);
79
80 const a3dShape: [number, number, number] = transposeA ?
81 [batchDimA, innerShapeA, outerShapeA] :
82 [batchDimA, outerShapeA, innerShapeA];
83 const b3dShape: [number, number, number] = transposeB ?
84 [batchDimB, outerShapeB, innerShapeB] :
85 [batchDimB, innerShapeB, outerShapeB];
86
87 // The rest of the implementation is designed to operate on rank-3 tensors
88 const a3d = reshape({inputs: {x: a}, backend, attrs: {shape: a3dShape}});
89 const b3d = reshape({inputs: {x: b}, backend, attrs: {shape: b3dShape}});
90 const intermediates: TensorInfo[] = [a3d, b3d];
91
92 const batchDim = Math.max(batchDimA, batchDimB);
93
94 const inputs: TensorInfo[] = [a3d, b3d];
95 const dimensions = [
96 {type: 'int32', data: [outerShapeA]}, {type: 'int32', data: [outerShapeB]},
97 {type: 'int32', data: [innerShapeA]}
98 ];
99
100 let program: WebGPUProgram;

Callers 4

_fusedMatMulFunction · 0.90
batchMatMulFunction · 0.90
conv2dByMatMulFunction · 0.90
conv2dWithIm2ColFunction · 0.90

Calls 13

reshapeFunction · 0.90
envFunction · 0.90
fillFunction · 0.90
maxMethod · 0.80
getNumberMethod · 0.80
ceilMethod · 0.80
runWebGPUProgramMethod · 0.80
isIntelMethod · 0.80
sliceMethod · 0.65
concatMethod · 0.65
disposeDataMethod · 0.65
getMethod · 0.45

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…