MCP · copy · Index your code
hub / github.com/tensorflow/tfjs / recordAndSubmit

Method recordAndSubmit

tfjs-backend-webgpu/src/backend_webgpu.ts:901–995  ·  view source on GitHub ↗
(
      program: webgpu_program.WebGPUProgram, output: TensorInfo,
      inputs: TensorInfo[], programDefinedUniform?: ProgramUniform)

Source from the content-addressed store, hash-verified

899 }
900
901 private recordAndSubmit(
902 program: webgpu_program.WebGPUProgram, output: TensorInfo,
903 inputs: TensorInfo[], programDefinedUniform?: ProgramUniform) {
904 if (program.pipeline instanceof Promise) {
905 throw new Error(
906 'Please call checkCompileCompletionAsync to ensure parallel compilation is done!');
907 }
908 // There are six kinds of uniforms: NAN, INFINITY, shapes, shape strides,
909 // program size, program defined uniforms.
910 let programUniform: ProgramUniform = [];
911 let bufferShapes: number[][] = [];
912 const uniformsType = 'int32';
913 if (program.pixelsOpType == null) {
914 programUniform.push(
915 {type: 'float32', data: [NaN]}, {type: 'float32', data: [Infinity]});
916 bufferShapes = inputs.concat(output).map(d => d.shape);
917 const uniformsType = 'int32';
918 bufferShapes.map(d => {
919 programUniform.push({type: uniformsType, data: d});
920 const strides = util.computeStrides(d);
921 programUniform.push({type: uniformsType, data: strides});
922 });
923 } else {
924 const strides = util.computeStrides(output.shape);
925 programUniform.push({type: uniformsType, data: strides});
926 }
927 if (program.size) {
928 const size = util.sizeFromShape(program.outputShape);
929 programUniform.push({
930 type: uniformsType,
931 data: [program.outputComponent ? size / program.outputComponent : size]
932 });
933 }
934
935 if (programDefinedUniform) {
936 programUniform = [...programUniform, ...programDefinedUniform];
937 }
938 const bindings = [
939 this.tensorToBinding(output), ...inputs.map(t => this.tensorToBinding(t)),
940 this.makeUniforms(programUniform)
941 ];
942
943 inputs.forEach(input => {
944 this.commandQueueOwnedIds.add(input.dataId);
945 });
946 this.commandQueueOwnedIds.add(output.dataId);
947
948 const bindGroup = this.device.createBindGroup({
949 layout: program.pipeline.getBindGroupLayout(0),
950 entries: bindings.map((b, i) => ({binding: i, resource: b})),
951 });
952
953 const shouldTimeProgram = this.activeTimers != null;
954 this.ensureCommandEncoderReady();
955
956 const computePassDescriptor: GPUComputePassDescriptor = {};
957 if (shouldTimeProgram && this.supportTimestampQuery) {
958 this.endComputePassEncoder();

Callers 1

runWebGPUProgram (method) · 0.95

Calls 11

tensorToBinding (method) · 0.95
makeUniforms (method) · 0.95
endComputePassEncoder (method) · 0.95
getQueryTime (method) · 0.95
submitQueue (method) · 0.95
env (function) · 0.90
concat (method) · 0.65
add (method) · 0.65
push (method) · 0.45
get (method) · 0.45

Tested by

no test coverage detected