MCPcopy
hub / github.com/tinygrad/tinygrad / __init__

Method __init__

tinygrad/runtime/graph/metal.py:12–50  ·  view source on GitHub ↗
(self, linear, input_uops=())

Source from the content-addressed store, hash-verified

10
11class MetalGraph(GraphRunner):
def __init__(self, linear, input_uops=()):
  """Pre-encode every call of this graph into a Metal indirect command buffer (ICB).

  One ICB command is recorded per entry of ``self.calls``. Each command gets:
  - its pipeline state;
  - the kernel buffers whose positions are not listed in that call's ``uop_replace`` entry;
  - slots in a shared int32 buffer for the kernel's symbolic variables;
  - a concurrent threadgroup dispatch;
  - a barrier.

  Raises:
    GraphException: if the device fails to create the indirect command buffer.
    AssertionError: if a call has no compiled runtime.
  """
  super().__init__(linear, input_uops)
  self.dev = cast(MetalDevice, Device[self.device])

  # create metal batch exec
  icb_descriptor = metal.MTLIndirectCommandBufferDescriptor.new()
  icb_descriptor.setCommandTypes(metal.MTLIndirectCommandTypeConcurrentDispatch)
  # every command binds its own buffers and pipeline state explicitly below
  icb_descriptor.setInheritBuffers(False)
  icb_descriptor.setInheritPipelineState(False)
  # NOTE(review): 31 presumably mirrors Metal's per-kernel buffer-argument limit; len(bufs) + number of vars must stay below it
  icb_descriptor.setMaxKernelBufferBindCount(31)

  self.icb = self.dev.sysdevice.newIndirectCommandBufferWithDescriptor_maxCommandCount_options(icb_descriptor, len(self.calls),
    metal.MTLResourceCPUCacheModeDefaultCache)
  if self.icb.value is None: raise GraphException("create indirect command buffer failed, does your system support this?")
  # ICB fix not required on M3+ (Apple9+). Assumes arch strings look like "Apple<N>"; a non-numeric suffix would raise ValueError.
  self.needs_icb_fix = int(not self.dev.arch.startswith("Apple") or int(self.dev.arch[5:]) < 9)

  # one int32 slot per symbolic variable, shared by every command
  int_size = dtypes.int32.itemsize
  if len(self.vars): self.int_buf = self.dev.allocator.alloc(len(self.vars)*int_size)

  all_pipelines = []
  all_resources = [self.int_buf.buf] if len(self.vars) else []
  # launch dims are encoded with every variable set to 0; presumably the real dims are patched per run (see launch_dims_replace)
  placeholder_vals = {v: 0 for v in self.vars}
  for j, ((_, ast, bufs, _), runtime, replace) in enumerate(zip(self.calls, self.runtimes, self.uop_replace)):
    assert runtime is not None
    icb_command = self.icb.indirectComputeCommandAtIndex(j).retained()
    icb_command.setComputePipelineState(runtime.pipeline_state)
    all_pipelines.append(runtime.pipeline_state)
    # positions listed in `replace` are not bound here; presumably they are bound per run from the inputs
    replaced = {pos for pos, _ in replace}
    for i, b in enumerate(bufs):
      if i in replaced: continue
      icb_command.setKernelBuffer_offset_atIndex(b._buf.buf, b._buf.offset, i)
      all_resources.append(b._buf.buf)
    # variable slots are bound right after the kernel buffers: indices len(bufs), len(bufs)+1, ...
    for i, v in enumerate(ast.arg.vars):
      icb_command.setKernelBuffer_offset_atIndex(self.int_buf.buf, self.vars.index(v.expr)*int_size, len(bufs)+i)
    global_size, local_size = ast.arg.launch_dims(placeholder_vals)
    icb_command.concurrentDispatchThreadgroups_threadsPerThreadgroup(metal.MTLSize(*global_size), metal.MTLSize(*local_size))
    # barrier between consecutive commands (see Metal MTLIndirectComputeCommand.setBarrier)
    icb_command.setBarrier()

  self.all_resources = dedup(all_resources)
  self.all_pipelines = dedup(all_pipelines)
  # last submitted command buffer; None until the graph has been run
  self.command_buffer: Any = None
  # host-side int32 memoryview over int_buf (presumably written with variable values each run)
  if len(self.vars): self.int_buf_view = cast(MetalAllocator, self.dev.allocator)._as_buffer(self.int_buf).cast('i')
  self.range = metal.NSRange(0, len(self.calls))
  # commands that may need updating per run: buffer replacements, variable values, or launch dims
  self.updatable = sorted({j for j,r in enumerate(self.uop_replace) if r} | self.var_vals_replace.keys() | self.launch_dims_replace.keys())
51
52 def __call__(self, input_uops:tuple[UOp, ...], var_vals:dict[str, int], wait=False):
53 if self.command_buffer is not None and self.command_buffer in self.dev.mtl_buffers_in_flight: wait_check(self.command_buffer)

Callers

nothing calls this directly

Calls 12

GraphException (class) · 0.90
dedup (function) · 0.90
cast (function) · 0.85
new (method) · 0.80
retained (method) · 0.80
append (method) · 0.80
launch_dims (method) · 0.80
keys (method) · 0.80
alloc (method) · 0.45
index (method) · 0.45
cast (method) · 0.45
_as_buffer (method) · 0.45

Tested by

no test coverage detected