MCPcopy Index your code
hub / github.com/tinygrad/tinygrad / compile_step

Function compile_step

examples/webgpu/stable_diffusion/compile.py:113–163  ·  view source on GitHub ↗
(model, step: Step)

Source from the content-addressed store, hash-verified

111 return code
112
def compile_step(model, step: Step):
  """Render one pipeline step as the source text of a JavaScript/WebGPU module.

  The step is traced with ``jit_model`` and lowered with ``compile_net``; the
  resulting kernels, buffers and call sequence are then spliced into a
  JavaScript template.

  Args:
    model: Object handed to ``get_state_dict`` to obtain the named weights.
    step: Step description; ``step.name`` becomes the JavaScript variable name
      and ``step.input`` supplies the example inputs used for tracing.

  Returns:
    A string of JavaScript of the form ``var <step.name> = function() {...}``.
    Calling that function yields an object whose async ``setup(device,
    safetensor)`` allocates the buffers and pipelines and resolves to an async
    runner taking one ``dataN`` argument per input. The runner returns a typed
    array holding the contents of ``output0``.

  Note:
    Only ``output0`` is copied back to the host, so the step must produce at
    least one output (indexing the first output fails otherwise).
  """
  linear, output_bufs = jit_model(step, *step.input)
  functions, statements, bufs, _ = compile_net(linear, output_bufs)
  state = get_state_dict(model)
  # Map every realized weight buffer (identity + offset/size/dtype) to its
  # state-dict name so exported buffers can be recognised as weights below.
  weights = {(id(b), b.offset, b.size, b.dtype): name for name, x in state.items() if (b:=x.uop.base.realized) is not None}

  num_inputs = len(step.input)
  input_names = [f"input{i}" for i in range(num_inputs)]
  output_names = [f"output{i}" for i in range(len(output_bufs))]
  # bufs maps buffer name -> (size, dtype, key); index 1 is the dtype.
  input_buf_types = [dtype_to_js_type(bufs[inp_name][1]) for inp_name in input_names]
  output_buf_types = [dtype_to_js_type(bufs[out_name][1]) for out_name in output_names]
  # Only the first output is read back by the generated runner.
  result_type = output_buf_types[0]
  result_itemsize = bufs[output_names[0]][1].itemsize

  # Separators matching the nesting depth of each section in the template below.
  setup_sep = '\n        '
  run_sep = '\n          '

  kernel_code = '\n\n'.join([f"const {key} = `{fixup_code(code, key)}`;" for key, code in functions.items()])
  kernel_names = ', '.join([name for (name, _, _, _) in statements])
  kernel_calls = run_sep.join([
    f"addComputePass(device, commandEncoder, piplines[{i}], [{', '.join(args)}], {global_size});"
    for i, (_name, args, global_size, _local_size) in enumerate(statements)
  ])

  # One declaration per buffer. The initializer is built WITHOUT a trailing
  # semicolon so that exactly one terminator is emitted for both branches
  # (previously the empty-buffer branch produced ";;").
  buf_decls = []
  for name, (size, _dtype, key) in bufs.items():
    if key in weights:
      weight_name = weights[key]
      init = f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weight_name}'], '{weight_name}'))"
    else:
      init = f"createEmptyBuf(device, {size})"
    buf_decls.append(f"const {name} = {init};")
  exported_bufs = setup_sep.join(buf_decls)

  # One mappable staging buffer per input, sized like the input it feeds.
  gpu_write_bufs = setup_sep.join([
    f"const gpuWriteBuffer{i} = device.createBuffer({{size:input{i}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});"
    for i in range(num_inputs)
  ])

  # Per input: map the staging buffer, copy the caller's data in, unmap, then
  # record the staging -> input copy on the command encoder.
  writer_lines = []
  for i, js_type in enumerate(input_buf_types):
    writer_lines += [
      f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);",
      f"new {js_type}(gpuWriteBuffer{i}.getMappedRange()).set(data{i});",
      f"gpuWriteBuffer{i}.unmap();",
      f"commandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, input{i}, 0, gpuWriteBuffer{i}.size);",
    ]
  input_writer = run_sep.join(writer_lines)

  data_params = ",".join([f"data{i}" for i in range(num_inputs)])

  # Doubled braces are literal JavaScript braces; single braces interpolate.
  return f"""
    var {step.name} = function() {{

    {kernel_code}

    return {{
      "setup": async (device, safetensor) => {{
        const metadata = safetensor ? getTensorMetadata(safetensor[0]) : null;

        {exported_bufs}

        {gpu_write_bufs}
        const gpuReadBuffer = device.createBuffer({{ size: output0.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});

        const kernels = [{kernel_names}];
        const piplines = await Promise.all(kernels.map(name => device.createComputePipelineAsync({{layout: "auto", compute: {{ module: device.createShaderModule({{ code: name }}), entryPoint: "main" }}}})));

        return async ({data_params}) => {{
          const commandEncoder = device.createCommandEncoder();

          {input_writer}

          {kernel_calls}
          commandEncoder.copyBufferToBuffer(output0, 0, gpuReadBuffer, 0, output0.size);
          const gpuCommands = commandEncoder.finish();
          device.queue.submit([gpuCommands]);

          await gpuReadBuffer.mapAsync(GPUMapMode.READ);
          const resultBuffer = new {result_type}(gpuReadBuffer.size/{result_itemsize});
          resultBuffer.set(new {result_type}(gpuReadBuffer.getMappedRange()));
          gpuReadBuffer.unmap();
          return resultBuffer;
        }}
      }}
    }}
  }}
  """
164
165 for step in sub_steps:
166 print(f'Executing step={step.name}')

Callers 1

compile.py · File · 0.85

Calls 6

jit_model · Function · 0.90
compile_net · Function · 0.90
get_state_dict · Function · 0.90
dtype_to_js_type · Function · 0.90
id · Function · 0.85
fixup_code · Function · 0.85

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…