(self, dev:MetalDevice, name:str, lib:bytes, **kwargs)
| 113 | |
| 114 | class MetalProgram: |
def __init__(self, dev:MetalDevice, name:str, lib:bytes, **kwargs):
    """Load a serialized Metal library and build a compute pipeline for one kernel.

    Args:
        dev: device whose ``sysdevice`` is used to create both the library and
            the compute pipeline state.
        name: name of the kernel function looked up in the loaded library.
        lib: serialized library bytes passed to ``newLibraryWithData_error``.
        **kwargs: accepted for call compatibility; not used here.

    Errors reported by Metal while creating the library or the pipeline are
    handed to ``error_check`` (presumably it raises on a non-empty NSError --
    confirm against its definition).
    """
    self.dev, self.name, self.lib = dev, name, lib

    # Hand the raw library bytes to libdispatch, then let the device build a library from them.
    lib_data = objc.dispatch_data_create(lib, len(lib), None, None)
    lib_err = metal.NSError().retained()
    self.library = self.dev.sysdevice.newLibraryWithData_error(lib_data, ctypes.byref(lib_err)).retained()
    error_check(lib_err)

    # Resolve the kernel entry point by name and describe the pipeline around it.
    self.fxn = self.library.newFunctionWithName(to_ns_str(name)).retained()
    pipeline_desc = metal.MTLComputePipelineDescriptor.new()
    pipeline_desc.setComputeFunction(self.fxn)
    pipeline_desc.setSupportIndirectCommandBuffers(True)

    pipeline_err = metal.NSError().retained()
    self.pipeline_state = self.dev.sysdevice.newComputePipelineStateWithDescriptor_options_reflection_error(
        pipeline_desc,
        metal.MTLPipelineOptionNone,
        None,  # no reflection info requested
        ctypes.byref(pipeline_err),
    )
    error_check(pipeline_err)

    # Query the threadgroup limit once here so each launch avoids an extra ObjC message send.
    self.max_total_threads: int = self.pipeline_state.maxTotalThreadsPerThreadgroup()
| 129 | |
| 130 | def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False, **kw): |
| 131 | if prod(local_size) > self.max_total_threads: |
nothing calls this directly
no test coverage detected