MCP · copy
hub / github.com/tinygrad/tinygrad / __init__

Method __init__

tinygrad/runtime/ops_metal.py:115–128  ·  view source on GitHub ↗
(self, dev:MetalDevice, name:str, lib:bytes, **kwargs)

Source from the content-addressed store, hash-verified

113
114 class MetalProgram:
115 def __init__(self, dev:MetalDevice, name:str, lib:bytes, **kwargs):
116 self.dev, self.name, self.lib = dev, name, lib
117 data = objc.dispatch_data_create(lib, len(lib), None, None)
118 self.library = self.dev.sysdevice.newLibraryWithData_error(data, ctypes.byref(error_lib:=metal.NSError().retained())).retained()
119 error_check(error_lib)
120 self.fxn = self.library.newFunctionWithName(to_ns_str(name)).retained()
121 descriptor = metal.MTLComputePipelineDescriptor.new()
122 descriptor.setComputeFunction(self.fxn)
123 descriptor.setSupportIndirectCommandBuffers(True)
124 self.pipeline_state = self.dev.sysdevice.newComputePipelineStateWithDescriptor_options_reflection_error(descriptor, metal.MTLPipelineOptionNone,
125 None, ctypes.byref(error_pipeline_creation:=metal.NSError().retained()))
126 error_check(error_pipeline_creation)
127 # cache these msg calls
128 self.max_total_threads: int = self.pipeline_state.maxTotalThreadsPerThreadgroup()
129
130 def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False, **kw):
131 if prod(local_size) > self.max_total_threads:

Callers

nothing calls this directly

Calls 4

error_check · Function · 0.85
to_ns_str · Function · 0.85
retained · Method · 0.80
new · Method · 0.80

Tested by

no test coverage detected