(
program: webgpu_program.WebGPUProgram, inputs: TensorInfo[],
outputDtype: DataType, programDefinedUniform?: ProgramUniform,
output?: TensorInfo)
| 848 | } |
| 849 | |
/**
 * Prepares `program` for execution over `inputs` and returns the tensor that
 * receives its result.
 *
 * Steps, in order:
 *  1. Create the output tensor when the caller did not supply one.
 *  2. Short-circuit if the output has zero elements.
 *  3. Upload the output and every input via `uploadToGPU`.
 *  4. Build the shader key and look up (or compile and cache) the pipeline.
 *  5. Unless the `WEBGPU_ENGINE_COMPILE_ONLY` flag is set, hand the program
 *     to `recordAndSubmit`.
 *
 * Note that `program` is mutated: its `dispatch`, `shaderKey` and `pipeline`
 * fields are assigned here.
 *
 * @param program The WebGPU program to run.
 * @param inputs Input tensors; `inputs[i]` is bound to
 *     `program.variableNames[i]`.
 * @param outputDtype Dtype of the output tensor. Only consulted when
 *     `output` is not provided.
 * @param programDefinedUniform Optional uniforms, forwarded unchanged to
 *     `recordAndSubmit`.
 * @param output Optional pre-existing tensor to write into. When omitted, a
 *     new one is created from `program.outputShape` and `outputDtype`.
 * @returns The output tensor info (the supplied `output`, or the newly
 *     created one).
 * @throws Error if any input has dtype `complex64`.
 */
public runWebGPUProgram(
    program: webgpu_program.WebGPUProgram, inputs: TensorInfo[],
    outputDtype: DataType, programDefinedUniform?: ProgramUniform,
    output?: TensorInfo): TensorInfo {
  if (!output) {
    output = this.makeTensorInfo(program.outputShape, outputDtype);
  }
  if (util.sizeFromShape(output.shape) === 0) {
    // Short-circuit the computation since the result is empty (has 0 in its
    // shape). Only a zero-length typed array is stored for the output.
    // NOTE(review): the `as 'float32'` cast presumably exists only to satisfy
    // the helper's type signature; the runtime dtype is still `output.dtype`.
    this.tensorMap.get(output.dataId).values =
        util.getTypedArrayFromDType(output.dtype as 'float32', 0);
    return output;
  }
  this.uploadToGPU(output.dataId);
  program.dispatch = reshapeDispatch(this.device, program);

  const inputsData = inputs.map((input: TensorInfo, i: number) => {
    if (input.dtype === 'complex64') {
      // The message names WebGPUProgram, the program type this method
      // actually accepts.
      throw new Error(
          `WebGPUProgram does not support complex64 input. For complex64 ` +
          `dtypes, please separate the program into real and imaginary ` +
          `parts.`);
    }
    this.uploadToGPU(input.dataId);

    return {
      // Returning dtype from tensorMap because it reflects dtype
      // of underlying buffer, rather than abstract dtype.
      dtype: this.tensorMap.get(input.dataId).dtype,
      shape: input.shape,
      name: program.variableNames[i]
    };
  });

  program.shaderKey =
      webgpu_program.makeShaderKey(program, inputsData, output);

  // When this flag is set the pipeline is still compiled and cached below,
  // but the program is not recorded or submitted.
  const parallelCompilation = env().getBool('WEBGPU_ENGINE_COMPILE_ONLY');
  if (!(program.shaderKey in this.pipelineCache)) {
    // Compile once per distinct shader key; later calls reuse the cache entry.
    this.pipelineCache[program.shaderKey] = webgpu_program.compileProgram(
        this.device, program, inputsData, output, parallelCompilation);
  }
  program.pipeline = this.pipelineCache[program.shaderKey];

  if (!parallelCompilation) {
    this.recordAndSubmit(program, output, inputs, programDefinedUniform);
  }
  return output;
}
| 900 | |
| 901 | private recordAndSubmit( |
| 902 | program: webgpu_program.WebGPUProgram, output: TensorInfo, |
no test coverage detected