MCPcopy Index your code
hub / github.com/tensorflow/tfjs / conv2dWithIm2Col

Function conv2dWithIm2Col

tfjs-backend-webgpu/src/kernels/Conv2D_impl.ts:175–273  ·  view source on GitHub ↗
({
  x,
  filter,
  convInfo,
  backend,
  bias = null,
  preluActivationWeights = null,
  leakyreluAlpha = 0,
  activation = null
}: Conv2DConfig)

Source from the content-addressed store, hash-verified

173// Implements the im2col algorithm as outlined in "High Performance
174// Convolutional Neural Networks for Document Processing" (Suvisoft, 2006)
175function conv2dWithIm2Col({
176 x,
177 filter,
178 convInfo,
179 backend,
180 bias = null,
181 preluActivationWeights = null,
182 leakyreluAlpha = 0,
183 activation = null
184}: Conv2DConfig) {
185 // Rearranges conv2d input so each block to be convolved over forms the
186 // row of a new matrix with shape [outHeight * outWidth,
187 // filterWidth * filterHeight * inChannels]. The filter is also rearranged so
188 // each output channel forms a col of a new matrix with shape [
189 // filterWidth * filterHeight * inChannels, outChannels]. The convolution is
190 // then computed by multiplying these matrices and reshaping the result.
191 const {
192 filterWidth,
193 filterHeight,
194 inChannels,
195 strideWidth,
196 strideHeight,
197 padInfo,
198 outWidth,
199 outHeight,
200 dilationWidth,
201 dilationHeight,
202 dataFormat
203 } = convInfo;
204
205 const isChannelsLast = dataFormat === 'channelsLast';
206
207 const sharedDim = filterWidth * filterHeight * inChannels;
208 const numCols = outHeight * outWidth;
209 const x2ColShape = isChannelsLast ? [convInfo.batchSize, numCols, sharedDim] :
210 [convInfo.batchSize, sharedDim, numCols];
211
212 const im2ColProgram = new Im2ColProgram(x2ColShape, isChannelsLast);
213 const dimensions = [
214 {type: 'int32', data: [padInfo.top, padInfo.left]}, // Padding.
215 {type: 'int32', data: [strideHeight, strideWidth]}, // Stride.
216 {type: 'int32', data: [dilationHeight, dilationWidth]}, // Dilation.
217 {type: 'int32', data: [outWidth]},
218 {type: 'int32', data: [inChannels * filterWidth]}, // itemsPerBlockRow.
219 {type: 'int32', data: [inChannels]}
220 ];
221 const x2Col =
222 backend.runWebGPUProgram(im2ColProgram, [x], x.dtype, dimensions);
223
224 const intermediates: TensorInfo[] = [];
225 intermediates.push(x2Col);
226
227 const filterReshaped = reshape(
228 {inputs: {x: filter}, backend, attrs: {shape: [1, sharedDim, -1]}});
229 intermediates.push(filterReshaped);
230
231 if (preluActivationWeights != null) {
232 const targetShape =

Callers 1

conv2DImpl (Function) · 0.85

Calls 6

reshape (Function) · 0.90
batchMatMulImpl (Function) · 0.90
runWebGPUProgram (Method) · 0.80
getShapeForBatchMatMul (Function) · 0.70
disposeData (Method) · 0.65
push (Method) · 0.45

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…