MCPcopy Index your code
hub / github.com/tinygrad/tinygrad / __call__

Method __call__

tinygrad/runtime/ops_webgpu.py:92–177  ·  view source on GitHub ↗
(self, *bufs:WGPUBufPtr, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1),
               vals:tuple[int, ...]=(), wait=False, **kw)

Source from the content-addressed store, hash-verified

90
91 self.name, self.lib, self.prg = name, lib, shader_module
92 def __call__(self, *bufs:WGPUBufPtr, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1),
93 vals:tuple[int, ...]=(), wait=False, **kw) -> float|None:
94 wait = wait and self.timestamp_supported
95 tmp_bufs = [*bufs]
96 buf_patch = False
97
98 # WebGPU does not allow using the same buffer for input and output
99 for i in range(1, len(bufs)):
100 if ctypes.addressof(bufs[i]) == ctypes.addressof(bufs[0]):
101 tmp_bufs[0] = webgpu.wgpuDeviceCreateBuffer(self.dev,
102 webgpu.WGPUBufferDescriptor(size=webgpu.wgpuBufferGetSize(bufs[0]), usage=webgpu.wgpuBufferGetUsage(bufs[0])))
103 buf_patch = True
104
105 # Creating bind group layout
106 binding_layouts = [webgpu.WGPUBindGroupLayoutEntry(binding=0, visibility= webgpu.WGPUShaderStage_Compute,
107 buffer=webgpu.WGPUBufferBindingLayout(type=webgpu.WGPUBufferBindingType_Uniform))]
108 binding_layouts += [webgpu.WGPUBindGroupLayoutEntry(binding=i+1, visibility=webgpu.WGPUShaderStage_Compute,
109 buffer=webgpu.WGPUBufferBindingLayout(type=webgpu.WGPUBufferBindingType_Uniform if i >= len(tmp_bufs)
110 else webgpu.WGPUBufferBindingType_Storage)) for i in range(len(tmp_bufs)+len(vals))]
111
112 bl_arr_type = webgpu.WGPUBindGroupLayoutEntry * len(binding_layouts)
113 webgpu.wgpuDevicePushErrorScope(self.dev, webgpu.WGPUErrorFilter_Validation)
114 bind_group_layouts = [webgpu.wgpuDeviceCreateBindGroupLayout(self.dev, webgpu.WGPUBindGroupLayoutDescriptor(
115 entryCount=len(binding_layouts), entries=ctypes.cast(bl_arr_type(*binding_layouts), ctypes.POINTER(webgpu.WGPUBindGroupLayoutEntry))))]
116
117 if bg_layout_err := pop_error(self.dev): raise RuntimeError(f"Error creating bind group layout: {bg_layout_err}")
118
119 # Creating pipeline layout
120 pipeline_layout_desc = webgpu.WGPUPipelineLayoutDescriptor(bindGroupLayoutCount=len(bind_group_layouts),
121 bindGroupLayouts = (webgpu.WGPUBindGroupLayout * len(bind_group_layouts))(*bind_group_layouts))
122
123 webgpu.wgpuDevicePushErrorScope(self.dev, webgpu.WGPUErrorFilter_Validation)
124 pipeline_layout = webgpu.wgpuDeviceCreatePipelineLayout(self.dev, pipeline_layout_desc)
125
126 if pipe_err := pop_error(self.dev): raise RuntimeError(f"Error creating pipeline layout: {pipe_err}")
127
128 # Creating bind group
129 bindings = [webgpu.WGPUBindGroupEntry(binding=0, buffer=create_uniform(self.dev, float('inf')), offset=0, size=4)]
130 bindings += [webgpu.WGPUBindGroupEntry(binding=i+1, buffer=create_uniform(self.dev, cast(int, x)) if i >= len(tmp_bufs) else x, offset=0,
131 size=4 if i >= len(tmp_bufs) else webgpu.wgpuBufferGetSize(x)) for i,x in enumerate(tuple(tmp_bufs)+vals)]
132
133 bg_arr_type = webgpu.WGPUBindGroupEntry * len(bindings)
134 bind_group_desc = webgpu.WGPUBindGroupDescriptor(layout=bind_group_layouts[0], entryCount=len(bindings), entries=bg_arr_type(*bindings))
135 webgpu.wgpuDevicePushErrorScope(self.dev, webgpu.WGPUErrorFilter_Validation)
136 bind_group = webgpu.wgpuDeviceCreateBindGroup(self.dev, bind_group_desc)
137
138 if bind_err := pop_error(self.dev): raise RuntimeError(f"Error creating bind group: {bind_err}")
139
140 # Creating compute pipeline
141 compute_desc = webgpu.WGPUComputePipelineDescriptor(layout=pipeline_layout,
142 compute=webgpu.WGPUComputeState(module=self.prg, entryPoint=to_wgpu_str(self.name)))
143 pipeline_result = _run(webgpu.wgpuDeviceCreateComputePipelineAsync2, webgpu.WGPUCreateComputePipelineAsyncCallbackInfo2,
144 webgpu.WGPUCreateComputePipelineAsyncCallback2, webgpu.WGPUCreatePipelineAsyncStatus, 1, None, self.dev, compute_desc)
145
146 command_encoder = webgpu.wgpuDeviceCreateCommandEncoder(self.dev, webgpu.WGPUCommandEncoderDescriptor())
147 comp_pass_desc = webgpu.WGPUComputePassDescriptor()
148
149 if wait:

Callers

nothing calls this directly

Calls 9

pop_error — Function · 0.85
create_uniform — Function · 0.85
cast — Function · 0.85
to_wgpu_str — Function · 0.85
copy_buffer_to_buffer — Function · 0.85
read_buffer — Function · 0.85
_run — Function · 0.70
cast — Method · 0.45
tolist — Method · 0.45

Tested by

no test coverage detected