MCPcopy
hub / github.com/tensorflow/tfjs / runWebGPUProgram

Method runWebGPUProgram

tfjs-backend-webgpu/src/backend_webgpu.ts:850–899  ·  view source on GitHub ↗
(
      program: webgpu_program.WebGPUProgram, inputs: TensorInfo[],
      outputDtype: DataType, programDefinedUniform?: ProgramUniform,
      output?: TensorInfo)

Source from the content-addressed store, hash-verified

848 }
849
/**
 * Prepares `program` for execution against `inputs` and, unless the
 * environment is in compile-only mode, records and submits it.
 *
 * Steps performed, in order:
 *  1. If no `output` is supplied, one is created with
 *     `makeTensorInfo(program.outputShape, outputDtype)`.
 *  2. If the output shape has size 0, the output's `tensorMap` entry gets the
 *     result of `util.getTypedArrayFromDType(dtype, 0)` as its `values` and
 *     the method returns immediately (no upload, no compile, no dispatch).
 *  3. The output and every input are passed to `uploadToGPU`.
 *  4. `program.dispatch`, `program.shaderKey` and `program.pipeline` are
 *     assigned on the passed-in `program` object (the caller's object is
 *     mutated).
 *  5. The pipeline is compiled only when `program.shaderKey` is not already
 *     a key of `this.pipelineCache`.
 *
 * NOTE(review): because of the early return in step 2, the complex64 input
 * check below is not reached when the output is empty — confirm that is
 * intended.
 *
 * @param program The program to run; mutated as described above.
 * @param inputs Input tensors, matched positionally with
 *     `program.variableNames`.
 * @param outputDtype Dtype used when this method has to create the output.
 * @param programDefinedUniform Forwarded unchanged to `recordAndSubmit`.
 * @param output Optional pre-existing output tensor info to write into.
 * @returns The output tensor info (either the one supplied or the one created
 *     here).
 * @throws Error if any input has dtype `'complex64'`.
 */
public runWebGPUProgram(
    program: webgpu_program.WebGPUProgram, inputs: TensorInfo[],
    outputDtype: DataType, programDefinedUniform?: ProgramUniform,
    output?: TensorInfo): TensorInfo {
  if (!output) {
    output = this.makeTensorInfo(program.outputShape, outputDtype);
  }

  const resultIsEmpty = util.sizeFromShape(output.shape) === 0;
  if (resultIsEmpty) {
    // A 0 somewhere in the shape means there is nothing to compute, so just
    // give the output its (empty) values and bail out.
    const outputData = this.tensorMap.get(output.dataId);
    outputData.values =
        util.getTypedArrayFromDType(output.dtype as 'float32', 0);
    return output;
  }

  this.uploadToGPU(output.dataId);
  program.dispatch = reshapeDispatch(this.device, program);

  const inputsData = inputs.map((tensor, idx) => {
    // NOTE(review): the message says "GPGPUProgram" although this is the
    // WebGPU code path — confirm whether the wording is intentional.
    if (tensor.dtype === 'complex64') {
      throw new Error(
          `GPGPUProgram does not support complex64 input. For complex64 ` +
          `dtypes, please separate the program into real and imaginary ` +
          `parts.`);
    }
    this.uploadToGPU(tensor.dataId);

    // The dtype is read from tensorMap rather than from the tensor info
    // because it reflects the dtype of the underlying buffer, not the
    // abstract dtype.
    const bufferDtype = this.tensorMap.get(tensor.dataId).dtype;
    return {
      dtype: bufferDtype,
      shape: tensor.shape,
      name: program.variableNames[idx]
    };
  });

  program.shaderKey =
      webgpu_program.makeShaderKey(program, inputsData, output);

  // When this flag is set the pipeline is still compiled and cached, but
  // nothing is recorded or submitted.
  const compileOnly = env().getBool('WEBGPU_ENGINE_COMPILE_ONLY');

  const pipelineIsCached = program.shaderKey in this.pipelineCache;
  if (!pipelineIsCached) {
    this.pipelineCache[program.shaderKey] = webgpu_program.compileProgram(
        this.device, program, inputsData, output, compileOnly);
  }
  program.pipeline = this.pipelineCache[program.shaderKey];

  if (compileOnly) {
    return output;
  }

  this.recordAndSubmit(program, output, inputs, programDefinedUniform);
  return output;
}
900
901 private recordAndSubmit(
902 program: webgpu_program.WebGPUProgram, output: TensorInfo,

Callers 15

reduceFunction · 0.80
unaryKernelFuncFunction · 0.80
binaryKernelFuncFunction · 0.80
sparseSegmentReduceFunction · 0.80
intFunction · 0.80
selectFunction · 0.80
Square.tsFile · 0.80
concatImplFunction · 0.80
transformFunction · 0.80
resizeBilinearFunction · 0.80
fromPixelsFunction · 0.80
cropAndResizeFunction · 0.80

Calls 7

makeTensorInfoMethod · 0.95
uploadToGPUMethod · 0.95
recordAndSubmitMethod · 0.95
envFunction · 0.90
reshapeDispatchFunction · 0.85
getBoolMethod · 0.80
getMethod · 0.45

Tested by

no test coverage detected