Repository: github.com/tensorflow/tfjs

Function getShapeForBatchMatMul

tfjs-backend-webgpu/src/kernels/Conv2D_impl.ts:46–65
(shape: number[], isChannelsLast: boolean)


// This function computes the target shape for fusing height and width
// dimensions. Returning null means the shape is already compatible.
// The return type includes null to match the final branch.
function getShapeForBatchMatMul(
    shape: number[], isChannelsLast: boolean): number[]|null {
  const length = shape.length;
  if (length >= 3) {
    return isChannelsLast ?
        [
          ...shape.slice(0, -3) /* batch */,
          shape[length - 3] * shape[length - 2] /* height * width */,
          shape[length - 1] /* channel */
        ] :
        [
          ...shape.slice(0, -3) /* batch */, shape[length - 3] /* channel */,
          shape[length - 2] * shape[length - 1] /* height * width */
        ];
  } else if (!isChannelsLast && length === 1 && shape[0] > 1) {
    // Rank-1, channels-first: expand [n] to [n, 1].
    return [shape[0], 1];
  } else {
    return null;
  }
}

// For 1x1 kernels that iterate through every point in the input, convolution
// can be expressed as matrix multiplication (without need for memory
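A self-contained copy of the function, runnable outside tfjs, makes the two rank-3-and-above branches concrete. This is a sketch: the helper `numElements` is hypothetical (written here, not imported from tfjs), and the return type is widened to `number[] | null`. For a channels-last (NHWC) shape the two dimensions before the channel (height, width) are fused; for channels-first (NCHW) the last two are; any leading batch dimensions pass through unchanged.

```typescript
// Local copy of getShapeForBatchMatMul, mirroring the tfjs source above.
function getShapeForBatchMatMul(
    shape: number[], isChannelsLast: boolean): number[]|null {
  const length = shape.length;
  if (length >= 3) {
    return isChannelsLast ?
        [
          ...shape.slice(0, -3),                  // batch dims
          shape[length - 3] * shape[length - 2],  // height * width
          shape[length - 1],                      // channel
        ] :
        [
          ...shape.slice(0, -3),                  // batch dims
          shape[length - 3],                      // channel
          shape[length - 2] * shape[length - 1],  // height * width
        ];
  } else if (!isChannelsLast && length === 1 && shape[0] > 1) {
    return [shape[0], 1];
  }
  return null;
}

// Hypothetical helper: total number of elements for a shape.
function numElements(shape: number[]): number {
  return shape.reduce((a, b) => a * b, 1);
}

// NHWC input: batch 2, 4x5 spatial, 3 channels.
const nhwc = [2, 4, 5, 3];
const fusedNhwc = getShapeForBatchMatMul(nhwc, true);
console.log(JSON.stringify(fusedNhwc));  // [2,20,3]

// NCHW input: batch 2, 3 channels, 4x5 spatial.
const nchw = [2, 3, 4, 5];
const fusedNchw = getShapeForBatchMatMul(nchw, false);
console.log(JSON.stringify(fusedNchw));  // [2,3,20]

// Only adjacent dimensions are merged, so the element count is unchanged.
console.log(numElements(nhwc) === numElements(fusedNhwc!));  // true
```

Because the merged dimensions are adjacent in a row-major buffer, the fused shape describes the same data in the same order, so reshaping to it should not require moving any elements.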

Callers (2)

conv2dByMatMul (function)
conv2dWithIm2Col (function)
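The source comment that follows this function says that, for 1x1 kernels, convolution can be expressed as matrix multiplication, and one listed caller is `conv2dByMatMul`. The sketch below checks that idea on plain arrays, assuming NHWC layout, stride 1, no padding, and a 1x1 filter stored as `[1, 1, inChannels, outChannels]` (which flattens to `[inChannels, outChannels]`). `conv1x1Naive` and `batchMatMul` are hypothetical helpers written for illustration; this is not the WebGPU kernel. Reading the input buffer with the fused shape `[batch, height * width, inChannels]`, which is what `getShapeForBatchMatMul` returns for channels-last input, turns the convolution into a batched matrix product.

```typescript
// Hypothetical naive 1x1 convolution over a flat row-major NHWC buffer:
// out[b, h, w, co] = sum over ci of x[b, h, w, ci] * filter[ci, co]
function conv1x1Naive(
    x: number[], filter: number[], batch: number, height: number,
    width: number, inChannels: number, outChannels: number): number[] {
  const out = new Array<number>(batch * height * width * outChannels).fill(0);
  for (let b = 0; b < batch; b++) {
    for (let h = 0; h < height; h++) {
      for (let w = 0; w < width; w++) {
        const pixel = (b * height + h) * width + w;
        for (let co = 0; co < outChannels; co++) {
          let acc = 0;
          for (let ci = 0; ci < inChannels; ci++) {
            acc += x[pixel * inChannels + ci] * filter[ci * outChannels + co];
          }
          out[pixel * outChannels + co] = acc;
        }
      }
    }
  }
  return out;
}

// Hypothetical batched matmul: a is [batch, m, k], bMat is [k, n] and is
// shared across the batch; the result is [batch, m, n].
function batchMatMul(
    a: number[], bMat: number[], batch: number, m: number, k: number,
    n: number): number[] {
  const out = new Array<number>(batch * m * n).fill(0);
  for (let b = 0; b < batch; b++) {
    for (let i = 0; i < m; i++) {
      for (let j = 0; j < n; j++) {
        let acc = 0;
        for (let p = 0; p < k; p++) {
          acc += a[(b * m + i) * k + p] * bMat[p * n + j];
        }
        out[(b * m + i) * n + j] = acc;
      }
    }
  }
  return out;
}

const [B, H, W, Cin, Cout] = [2, 3, 4, 5, 6];
const x = Array.from({length: B * H * W * Cin}, (_, i) => (i % 7) - 3);
const filter = Array.from({length: Cin * Cout}, (_, i) => (i % 5) - 2);

const viaConv = conv1x1Naive(x, filter, B, H, W, Cin, Cout);
// Same flat buffers, read with the fused shape [B, H * W, Cin].
const viaMatMul = batchMatMul(x, filter, B, H * W, Cin, Cout);
console.log(viaConv.every((v, i) => v === viaMatMul[i]));  // true
```

The two results agree element for element because row `i = h * width + w` of the fused view is exactly pixel `(h, w)`, so both loops read and write the same flat indices.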

Calls (1)

slice (method)

Tested by

no test coverage detected
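Since the index reports no test coverage, a table-driven check is one way to pin down every branch, including the ones that return null. The cases and expected values below are derived by reading the code, not taken from the tfjs test suite, and the function is again a local copy for illustration.

```typescript
// Local copy of getShapeForBatchMatMul, mirroring the tfjs source.
function getShapeForBatchMatMul(
    shape: number[], isChannelsLast: boolean): number[]|null {
  const length = shape.length;
  if (length >= 3) {
    return isChannelsLast ?
        [...shape.slice(0, -3), shape[length - 3] * shape[length - 2],
         shape[length - 1]] :
        [...shape.slice(0, -3), shape[length - 3],
         shape[length - 2] * shape[length - 1]];
  } else if (!isChannelsLast && length === 1 && shape[0] > 1) {
    return [shape[0], 1];
  }
  return null;
}

// [input shape, isChannelsLast, expected result]
const cases: Array<[number[], boolean, number[] | null]> = [
  [[4, 5, 3], true, [20, 3]],              // rank 3, no batch dims
  [[3, 4, 5], false, [3, 20]],             // rank 3, channels-first
  [[1, 2, 4, 5, 3], true, [1, 2, 20, 3]],  // rank 5: both batch dims kept
  [[8], false, [8, 1]],                    // rank 1, channels-first, n > 1
  [[1], false, null],                      // rank 1 with a single element
  [[8], true, null],                       // rank 1, channels-last
  [[4, 5], false, null],                   // rank 2: already compatible
  [[], true, null],                        // rank 0
];

for (const [shape, channelsLast, expected] of cases) {
  const actual = getShapeForBatchMatMul(shape, channelsLast);
  const ok = JSON.stringify(actual) === JSON.stringify(expected);
  console.log(
      `${JSON.stringify(shape)} channelsLast=${channelsLast}: ` +
      (ok ? 'ok' : 'FAIL'));
}
```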
