(
inputs: ConcatInputs, axis: number, backend: WebGPUBackend)
| 27 | import {reshape} from './Reshape'; |
| 28 | |
| 29 | export function concatImpl( |
| 30 | inputs: ConcatInputs, axis: number, backend: WebGPUBackend): TensorInfo { |
| 31 | const dtype = inputs[0].dtype; |
| 32 | if (dtype === 'complex64') { |
| 33 | const reals = inputs.map((t) => real({inputs: {input: t}, backend})); |
| 34 | const imags = inputs.map((t) => imag({inputs: {input: t}, backend})); |
| 35 | |
| 36 | const realConcated = concatImpl(reals, axis, backend); |
| 37 | const imagConcated = concatImpl(imags, axis, backend); |
| 38 | |
| 39 | const result = |
| 40 | complex({inputs: {real: realConcated, imag: imagConcated}, backend}); |
| 41 | |
| 42 | reals.forEach(r => backend.disposeData(r.dataId)); |
| 43 | imags.forEach(i => backend.disposeData(i.dataId)); |
| 44 | backend.disposeData(realConcated.dataId); |
| 45 | backend.disposeData(imagConcated.dataId); |
| 46 | |
| 47 | return result; |
| 48 | } |
| 49 | |
| 50 | let runOnCpu = backend.shouldExecuteOnCPU(inputs); |
| 51 | |
| 52 | // Run on cpu if dtype is string. For string, the backend represents it |
| 53 | // as Uint8Array[], where each Uint8Array is a character. Given that the |
| 54 | // computation is only on the outer array, uploading the whole data onto |
| 55 | // gpu is wasteful. Also, currently webgpu doesn't have a design to |
| 56 | // upload and retrieve Uint8Array[] between cpu and gpu. Therefore, we |
| 57 | // just run the kernel on cpu if dtype is string. |
| 58 | if (dtype === 'string') { |
| 59 | runOnCpu = true; |
| 60 | } |
| 61 | |
| 62 | if (runOnCpu) { |
| 63 | // Any concat of n-dimensional tensors across any axis can be reduced to |
| 64 | // a concatenation of two-dimensional tensors across the axis 1 by first |
| 65 | // partitioning the axes of the original tensors into those less than the |
| 66 | // axis to be concatenated and the rest. Then reshape the tensors |
| 67 | // into a two-dimensional tensor by collapsing these two sets of axes and |
| 68 | // concatenate the resulting matrices across the axis 1, finally reshaping |
| 69 | // the result to have the proper shape. |
| 70 | const tensors2D = inputs.map(t => { |
| 71 | const innerSize = util.sizeFromShape(t.shape.slice(axis)); |
| 72 | const shape = [-1, innerSize]; |
| 73 | return reshape({inputs: {x: t}, backend, attrs: {shape}}); |
| 74 | }); |
| 75 | |
| 76 | const inputsValShapes = tensors2D.map(t => { |
| 77 | return {vals: backend.readSync(t.dataId), shape: t.shape}; |
| 78 | }); |
| 79 | |
| 80 | // Concats 2d tensors along axis=1. |
| 81 | const outShape = |
| 82 | backend_util.computeOutShape(tensors2D.map(t => t.shape), 1 /* axis */); |
| 83 | const simplyConcat = tensors2D[0].shape[0] === 1; |
| 84 | const outVals = |
| 85 | concatImplCPU(inputsValShapes, outShape, dtype, simplyConcat); |
| 86 |
no test coverage detected
searching dependent graphs…