({
a,
b,
transposeA,
transposeB,
backend,
bias = null,
preluActivationWeights = null,
leakyreluAlpha = 0,
activation = null
}: BatchMatMulConfig)
| 41 | }; |
| 42 | |
| 43 | export function batchMatMulImpl({ |
| 44 | a, |
| 45 | b, |
| 46 | transposeA, |
| 47 | transposeB, |
| 48 | backend, |
| 49 | bias = null, |
| 50 | preluActivationWeights = null, |
| 51 | leakyreluAlpha = 0, |
| 52 | activation = null |
| 53 | }: BatchMatMulConfig): TensorInfo { |
| 54 | const aRank = a.shape.length; |
| 55 | const bRank = b.shape.length; |
| 56 | |
| 57 | const innerShapeA = transposeA ? a.shape[aRank - 2] : a.shape[aRank - 1]; |
| 58 | const innerShapeB = transposeB ? b.shape[bRank - 1] : b.shape[bRank - 2]; |
| 59 | |
| 60 | const outerShapeA = transposeA ? a.shape[aRank - 1] : a.shape[aRank - 2]; |
| 61 | const outerShapeB = transposeB ? b.shape[bRank - 2] : b.shape[bRank - 1]; |
| 62 | |
| 63 | const outerDimsA = a.shape.slice(0, -2); |
| 64 | const outerDimsB = b.shape.slice(0, -2); |
| 65 | |
| 66 | const batchDimA = util.sizeFromShape(outerDimsA); |
| 67 | const batchDimB = util.sizeFromShape(outerDimsB); |
| 68 | |
| 69 | const outShapeOuterDims = broadcast_util.assertAndGetBroadcastShape( |
| 70 | a.shape.slice(0, -2), b.shape.slice(0, -2)); |
| 71 | const outShape = outShapeOuterDims.concat([outerShapeA, outerShapeB]); |
| 72 | |
| 73 | util.assert( |
| 74 | innerShapeA === innerShapeB, |
| 75 | () => `Error in matMul: inner shapes (${innerShapeA}) and (` + |
| 76 | `${innerShapeB}) of Tensors with shapes ${a.shape} and ` + |
| 77 | `${b.shape} and transposeA=${transposeA}` + |
| 78 | ` and transposeB=${transposeB} must match.`); |
| 79 | |
| 80 | const a3dShape: [number, number, number] = transposeA ? |
| 81 | [batchDimA, innerShapeA, outerShapeA] : |
| 82 | [batchDimA, outerShapeA, innerShapeA]; |
| 83 | const b3dShape: [number, number, number] = transposeB ? |
| 84 | [batchDimB, outerShapeB, innerShapeB] : |
| 85 | [batchDimB, innerShapeB, outerShapeB]; |
| 86 | |
| 87 | // The rest of the implementation is designed to operate on rank-3 tensors |
| 88 | const a3d = reshape({inputs: {x: a}, backend, attrs: {shape: a3dShape}}); |
| 89 | const b3d = reshape({inputs: {x: b}, backend, attrs: {shape: b3dShape}}); |
| 90 | const intermediates: TensorInfo[] = [a3d, b3d]; |
| 91 | |
| 92 | const batchDim = Math.max(batchDimA, batchDimB); |
| 93 | |
| 94 | const inputs: TensorInfo[] = [a3d, b3d]; |
| 95 | const dimensions = [ |
| 96 | {type: 'int32', data: [outerShapeA]}, {type: 'int32', data: [outerShapeB]}, |
| 97 | {type: 'int32', data: [innerShapeA]} |
| 98 | ]; |
| 99 | |
| 100 | let program: WebGPUProgram; |
no test coverage detected
searching dependent graphs…