(self, *bufs:WGPUBufPtr, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1),
vals:tuple[int, ...]=(), wait=False, **kw)
| 90 | |
| 91 | self.name, self.lib, self.prg = name, lib, shader_module |
| 92 | def __call__(self, *bufs:WGPUBufPtr, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), |
| 93 | vals:tuple[int, ...]=(), wait=False, **kw) -> float|None: |
| 94 | wait = wait and self.timestamp_supported |
| 95 | tmp_bufs = [*bufs] |
| 96 | buf_patch = False |
| 97 | |
| 98 | # WebGPU does not allow using the same buffer for input and output |
| 99 | for i in range(1, len(bufs)): |
| 100 | if ctypes.addressof(bufs[i]) == ctypes.addressof(bufs[0]): |
| 101 | tmp_bufs[0] = webgpu.wgpuDeviceCreateBuffer(self.dev, |
| 102 | webgpu.WGPUBufferDescriptor(size=webgpu.wgpuBufferGetSize(bufs[0]), usage=webgpu.wgpuBufferGetUsage(bufs[0]))) |
| 103 | buf_patch = True |
| 104 | |
| 105 | # Creating bind group layout |
| 106 | binding_layouts = [webgpu.WGPUBindGroupLayoutEntry(binding=0, visibility= webgpu.WGPUShaderStage_Compute, |
| 107 | buffer=webgpu.WGPUBufferBindingLayout(type=webgpu.WGPUBufferBindingType_Uniform))] |
| 108 | binding_layouts += [webgpu.WGPUBindGroupLayoutEntry(binding=i+1, visibility=webgpu.WGPUShaderStage_Compute, |
| 109 | buffer=webgpu.WGPUBufferBindingLayout(type=webgpu.WGPUBufferBindingType_Uniform if i >= len(tmp_bufs) |
| 110 | else webgpu.WGPUBufferBindingType_Storage)) for i in range(len(tmp_bufs)+len(vals))] |
| 111 | |
| 112 | bl_arr_type = webgpu.WGPUBindGroupLayoutEntry * len(binding_layouts) |
| 113 | webgpu.wgpuDevicePushErrorScope(self.dev, webgpu.WGPUErrorFilter_Validation) |
| 114 | bind_group_layouts = [webgpu.wgpuDeviceCreateBindGroupLayout(self.dev, webgpu.WGPUBindGroupLayoutDescriptor( |
| 115 | entryCount=len(binding_layouts), entries=ctypes.cast(bl_arr_type(*binding_layouts), ctypes.POINTER(webgpu.WGPUBindGroupLayoutEntry))))] |
| 116 | |
| 117 | if bg_layout_err := pop_error(self.dev): raise RuntimeError(f"Error creating bind group layout: {bg_layout_err}") |
| 118 | |
| 119 | # Creating pipeline layout |
| 120 | pipeline_layout_desc = webgpu.WGPUPipelineLayoutDescriptor(bindGroupLayoutCount=len(bind_group_layouts), |
| 121 | bindGroupLayouts = (webgpu.WGPUBindGroupLayout * len(bind_group_layouts))(*bind_group_layouts)) |
| 122 | |
| 123 | webgpu.wgpuDevicePushErrorScope(self.dev, webgpu.WGPUErrorFilter_Validation) |
| 124 | pipeline_layout = webgpu.wgpuDeviceCreatePipelineLayout(self.dev, pipeline_layout_desc) |
| 125 | |
| 126 | if pipe_err := pop_error(self.dev): raise RuntimeError(f"Error creating pipeline layout: {pipe_err}") |
| 127 | |
| 128 | # Creating bind group |
| 129 | bindings = [webgpu.WGPUBindGroupEntry(binding=0, buffer=create_uniform(self.dev, float('inf')), offset=0, size=4)] |
| 130 | bindings += [webgpu.WGPUBindGroupEntry(binding=i+1, buffer=create_uniform(self.dev, cast(int, x)) if i >= len(tmp_bufs) else x, offset=0, |
| 131 | size=4 if i >= len(tmp_bufs) else webgpu.wgpuBufferGetSize(x)) for i,x in enumerate(tuple(tmp_bufs)+vals)] |
| 132 | |
| 133 | bg_arr_type = webgpu.WGPUBindGroupEntry * len(bindings) |
| 134 | bind_group_desc = webgpu.WGPUBindGroupDescriptor(layout=bind_group_layouts[0], entryCount=len(bindings), entries=bg_arr_type(*bindings)) |
| 135 | webgpu.wgpuDevicePushErrorScope(self.dev, webgpu.WGPUErrorFilter_Validation) |
| 136 | bind_group = webgpu.wgpuDeviceCreateBindGroup(self.dev, bind_group_desc) |
| 137 | |
| 138 | if bind_err := pop_error(self.dev): raise RuntimeError(f"Error creating bind group: {bind_err}") |
| 139 | |
| 140 | # Creating compute pipeline |
| 141 | compute_desc = webgpu.WGPUComputePipelineDescriptor(layout=pipeline_layout, |
| 142 | compute=webgpu.WGPUComputeState(module=self.prg, entryPoint=to_wgpu_str(self.name))) |
| 143 | pipeline_result = _run(webgpu.wgpuDeviceCreateComputePipelineAsync2, webgpu.WGPUCreateComputePipelineAsyncCallbackInfo2, |
| 144 | webgpu.WGPUCreateComputePipelineAsyncCallback2, webgpu.WGPUCreatePipelineAsyncStatus, 1, None, self.dev, compute_desc) |
| 145 | |
| 146 | command_encoder = webgpu.wgpuDeviceCreateCommandEncoder(self.dev, webgpu.WGPUCommandEncoderDescriptor()) |
| 147 | comp_pass_desc = webgpu.WGPUComputePassDescriptor() |
| 148 | |
| 149 | if wait: |
nothing calls this directly
no test coverage detected