(
shape: number[], isChannelsLast: boolean)
// This function computes the target shape for fusing the height and width
// dimensions of `shape` into a single dimension. Returning null means the
// shape is already compatible and needs no reshape.
//
// Params:
//   shape: the shape to convert. For rank >= 3 the trailing three dims are
//     read as [H, W, C] when `isChannelsLast` is true and as [C, H, W]
//     otherwise; any leading dims are kept unchanged as batch dims.
//   isChannelsLast: whether the channel dimension is the innermost one.
// Returns:
//   - rank >= 3: the same shape with H and W multiplied into one dim,
//     i.e. [...batch, H*W, C] or [...batch, C, H*W].
//   - rank 1, channels-first, size > 1: [C, 1]. This is presumably so that,
//     under standard broadcasting, the vector lines up with the channel axis
//     of a [..., C, H*W] layout rather than the last axis.
//   - anything else (rank 0 or 2, channels-last rank 1, or a size-1 vector):
//     null.
function getShapeForBatchMatMul(
    shape: readonly number[], isChannelsLast: boolean): number[]|null {
  const rank = shape.length;
  if (rank >= 3) {
    const batch = shape.slice(0, -3);
    return isChannelsLast ?
        [
          ...batch,
          shape[rank - 3] * shape[rank - 2] /* height * width */,
          shape[rank - 1] /* channel */
        ] :
        [
          ...batch,
          shape[rank - 3] /* channel */,
          shape[rank - 2] * shape[rank - 1] /* height * width */
        ];
  }
  if (!isChannelsLast && rank === 1 && shape[0] > 1) {
    return [shape[0], 1];
  }
  return null;
}
| 66 | |
| 67 | // For 1x1 kernels that iterate through every point in the input, convolution |
| 68 | // can be expressed as matrix multiplication (without need for memory |
no test coverage detected
searching dependent graphs…