(self, linear, input_uops=())
| 10 | |
| 11 | class MetalGraph(GraphRunner): |
def __init__(self, linear, input_uops=()):
  """Encode every call of the graph into a single Metal indirect command buffer (ICB).

  `linear` and `input_uops` are forwarded unchanged to `GraphRunner.__init__`
  (presumably that is what populates `self.calls`, `self.runtimes`, `self.vars`
  and the `*_replace` tables read below -- confirm against the base class).

  Raises:
    GraphException: if the device returns no indirect command buffer.
  """
  super().__init__(linear, input_uops)
  self.dev = cast(MetalDevice, Device[self.device])

  # describe the ICB: concurrent-dispatch commands, nothing inherited from the encoder
  desc = metal.MTLIndirectCommandBufferDescriptor.new()
  desc.setCommandTypes(metal.MTLIndirectCommandTypeConcurrentDispatch)
  desc.setInheritBuffers(False)
  desc.setInheritPipelineState(False)
  desc.setMaxKernelBufferBindCount(31)

  # one ICB slot per call in the graph
  self.icb = self.dev.sysdevice.newIndirectCommandBufferWithDescriptor_maxCommandCount_options(
    desc, len(self.calls), metal.MTLResourceCPUCacheModeDefaultCache)
  if self.icb.value is None:
    raise GraphException("create indirect command buffer failed, does your system support this?")

  # ICB fix not required on M3+ (Apple9+)
  arch = self.dev.arch
  self.needs_icb_fix = int(not arch.startswith("Apple") or int(arch[5:]) < 9)

  pipelines, resources = [], []
  if len(self.vars):
    # a single int32 slot per graph variable, shared by every kernel that reads one
    self.int_buf = self.dev.allocator.alloc(len(self.vars)*dtypes.int32.itemsize)
    resources.append(self.int_buf.buf)

  for idx, (call, runtime, replace) in enumerate(zip(self.calls, self.runtimes, self.uop_replace)):
    _, ast, bufs, _ = call
    assert runtime is not None
    cmd = self.icb.indirectComputeCommandAtIndex(idx).retained()
    cmd.setComputePipelineState(runtime.pipeline_state)
    pipelines.append(runtime.pipeline_state)

    for slot, buf in enumerate(bufs):
      # positions listed in `replace` are left unbound here
      # (presumably they are bound per call -- confirm in __call__)
      if any(pos == slot for pos, _ in replace): continue
      raw = buf._buf
      cmd.setKernelBuffer_offset_atIndex(raw.buf, raw.offset, slot)
      resources.append(raw.buf)

    # variables are bound after the buffers, each pointing at its 4-byte slot in int_buf
    for bind_idx, var in enumerate(ast.arg.vars, start=len(bufs)):
      cmd.setKernelBuffer_offset_atIndex(self.int_buf.buf, self.vars.index(var.expr)*4, bind_idx)

    # launch dims are computed with every variable set to 0
    global_size, local_size = ast.arg.launch_dims(dict.fromkeys(self.vars, 0))
    cmd.concurrentDispatchThreadgroups_threadsPerThreadgroup(metal.MTLSize(*global_size), metal.MTLSize(*local_size))
    cmd.setBarrier()

  self.all_resources = dedup(resources)
  self.all_pipelines = dedup(pipelines)
  self.command_buffer: Any = None
  if len(self.vars):
    # int-typed view over int_buf
    allocator = cast(MetalAllocator, self.dev.allocator)
    self.int_buf_view = allocator._as_buffer(self.int_buf).cast('i')
  self.range = metal.NSRange(0, len(self.calls))

  # sorted indices of calls that have replaced buffers, replaced var values or replaced launch dims
  dynamic = {idx for idx, repl in enumerate(self.uop_replace) if repl}
  dynamic.update(self.var_vals_replace.keys())
  dynamic.update(self.launch_dims_replace.keys())
  self.updatable = sorted(dynamic)
| 51 | |
| 52 | def __call__(self, input_uops:tuple[UOp, ...], var_vals:dict[str, int], wait=False): |
| 53 | if self.command_buffer is not None and self.command_buffer in self.dev.mtl_buffers_in_flight: wait_check(self.command_buffer) |
nothing calls this directly
no test coverage detected