MCPcopy Index your code
hub / github.com/tensorflow/tfjs / concatImpl

Function concatImpl

tfjs-backend-webgpu/src/kernels/Concat_impl.ts:29–138  ·  view source on GitHub ↗
(
    inputs: ConcatInputs, axis: number, backend: WebGPUBackend): TensorInfo

Source from the content-addressed store, hash-verified. (Excerpt truncated: only lines 27–86 are shown; the function continues to line 138.)

27import {reshape} from './Reshape';
28
29export function concatImpl(
30 inputs: ConcatInputs, axis: number, backend: WebGPUBackend): TensorInfo {
31 const dtype = inputs[0].dtype;
32 if (dtype === 'complex64') {
33 const reals = inputs.map((t) => real({inputs: {input: t}, backend}));
34 const imags = inputs.map((t) => imag({inputs: {input: t}, backend}));
35
36 const realConcated = concatImpl(reals, axis, backend);
37 const imagConcated = concatImpl(imags, axis, backend);
38
39 const result =
40 complex({inputs: {real: realConcated, imag: imagConcated}, backend});
41
42 reals.forEach(r => backend.disposeData(r.dataId));
43 imags.forEach(i => backend.disposeData(i.dataId));
44 backend.disposeData(realConcated.dataId);
45 backend.disposeData(imagConcated.dataId);
46
47 return result;
48 }
49
50 let runOnCpu = backend.shouldExecuteOnCPU(inputs);
51
52 // Run on cpu if dtype is string. For string, the backend represents it
53 // as Uint8Array[], where each Uint8Array is a character. Given that the
54 // computation is only on the outer array, uploading the whole data onto
55 // gpu is wasteful. Also, currently webgpu doesn't have a design to
56 // upload and retrieve Uint8Array[] between cpu and gpu. Therefore, we
57 // just run the kernel on cpu if dtype is string.
58 if (dtype === 'string') {
59 runOnCpu = true;
60 }
61
62 if (runOnCpu) {
63 // Any concat of n-dimensional tensors across any axis can be reduced to
64 // a concatenation of two-dimensional tensors across the axis 1 by first
65 // partitioning the axes of the original tensors into those less than the
66 // axis to be concatenated and the rest. Then reshape the tensors
67 // into a two-dimensional tensor by collapsing these two sets of axes and
68 // concatenate the resulting matrices across the axis 1, finally reshaping
69 // the result to have the proper shape.
70 const tensors2D = inputs.map(t => {
71 const innerSize = util.sizeFromShape(t.shape.slice(axis));
72 const shape = [-1, innerSize];
73 return reshape({inputs: {x: t}, backend, attrs: {shape}});
74 });
75
76 const inputsValShapes = tensors2D.map(t => {
77 return {vals: backend.readSync(t.dataId), shape: t.shape};
78 });
79
80 // Concats 2d tensors along axis=1.
81 const outShape =
82 backend_util.computeOutShape(tensors2D.map(t => t.shape), 1 /* axis */);
83 const simplyConcat = tensors2D[0].shape[0] === 1;
84 const outVals =
85 concatImplCPU(inputsValShapes, outShape, dtype, simplyConcat);
86

Callers (1)

concat · Function · 0.90

Calls (12)

real · Function · 0.90
imag · Function · 0.90
complex · Function · 0.90
reshape · Function · 0.90
runWebGPUProgram · Method · 0.80
computeTensors2D · Function · 0.70
disposeData · Method · 0.65
slice · Method · 0.65
readSync · Method · 0.65
shouldExecuteOnCPU · Method · 0.45
makeTensorInfo · Method · 0.45
push · Method · 0.45

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…