MCP · copy · Index your code
hub / github.com/tinygrad/tinygrad / export_model_webgpu

Function export_model_webgpu

extra/export_model.py:115–239  ·  view source on GitHub ↗
(functions, statements, bufs, weight_names, input_names, output_names, model_name, symbolic_vars={}, stream_weights=False)

Source from the content-addressed store, hash-verified

113 return f"{'Uint' if dtype in dtypes.uints else 'Int' if (dtype in dtypes.sints or dtype == dtypes.bool) else 'Float'}{8*dtype.itemsize}Array"
114
115def export_model_webgpu(functions, statements, bufs, weight_names, input_names, output_names, model_name, symbolic_vars={}, stream_weights=False) -> Tuple[str,int,int]:
116 kernel_code = '\n\n'.join([f"const {key} = `{code.replace(key, 'main')}`;" for key, code in functions.items()])
117 kernel_names = ', '.join([name for (name, _, _, _) in statements])
118 input_names += list(symbolic_vars.values())
119 input_buffer_types = [dtype_to_js_type(bufs[inp_name][1]) for inp_name in input_names]
120 output_buffer_types = [dtype_to_js_type(bufs[out_name][1]) for out_name in output_names]
121
122 buf_type = lambda x: "uniform" if x in set(symbolic_vars.values()) else "storage"
123 create_bind_group_layouts = ",".join([
124 "device.createBindGroupLayout({{entries: [{{binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: {{ type: 'uniform' }}}}, {}]}})".format(
125 ",".join([f"{{binding: {argIdx+1}, visibility: GPUShaderStage.COMPUTE, buffer: {{ type: '{buf_type(argName)}' }} }}" for argIdx, argName in enumerate(args)])
126 )
127 for _, (_, args, _, _) in enumerate(statements)
128 ])
129 layouts = f"const layouts=[{create_bind_group_layouts}]"
130 kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, pipelines[{i}], layouts[{i}], infinityBuf, [{', '.join(args)}], [{', '.join(str(x) for x in global_size)}]);" for i, (_name, args, global_size, _local_size) in enumerate(statements) ])
131
132 buf_type = lambda x: "createUniformBuf" if x in set(uop.arg[0] for uop in symbolic_vars) else "createEmptyBuf"
133 map_to_external_weight = lambda _key: f"state_dict['{weight_names[_key]}']" if stream_weights else f"getTensorBuffer(safetensor, metadata['{weight_names[_key]}'])"
134 _bufs = '\n '.join([f"const {name} = " + (f"{buf_type(_key)}(device, {size});" if _key not in weight_names else f"createWeightBuf(device, {size}, {map_to_external_weight(_key)})") + ";" for name,(size,dtype,_key) in bufs.items()])
135 gpu_write_bufs = '\n '.join([f"const gpuWriteBuffer{i} = device.createBuffer({{size:{input_name}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});" for i,input_name in enumerate(input_names)])
136 input_writers = '\n '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new {input_buffer_types[i]}(gpuWriteBuffer{i}.getMappedRange()).set(" + f'_{inp_name});' + f"\n gpuWriteBuffer{i}.unmap();\n commandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, {inp_name}, 0, gpuWriteBuffer{i}.size);" for i,inp_name in enumerate(input_names)])
137 gpu_read_bufs = '\n '.join([f"const gpuReadBuffer{i} = device.createBuffer({{size:{output_name}.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});" for i,output_name in enumerate(output_names)])
138 outbuf_copies = '\n '.join([f"commandEncoder.copyBufferToBuffer({output_name}, 0, gpuReadBuffer{i}, 0, output{i}.size);" for i,output_name in enumerate(output_names)])
139 output_readers = '\n '.join([f"await gpuReadBuffer{i}.mapAsync(GPUMapMode.READ);\n const resultBuffer{i} = new {output_buffer_types[i]}(gpuReadBuffer{i}.size/{bufs[output_names[i]][1].itemsize});\n resultBuffer{i}.set(new {output_buffer_types[i]}(gpuReadBuffer{i}.getMappedRange()));\n gpuReadBuffer{i}.unmap();" for i in range(len(output_names))])
140 output_return = '[{}]'.format(",".join([f'resultBuffer{i}' for i in range(len(output_names))]))
141 getTensorMetadata = f"""\nconst getTensorMetadata = (safetensorBuffer) => {{
142 const metadataLength = Number(new DataView(safetensorBuffer.buffer).getBigUint64(0, true));
143 const metadata = JSON.parse(new TextDecoder("utf8").decode(safetensorBuffer.subarray(8, 8 + metadataLength)));
144 return Object.fromEntries(Object.entries(metadata).filter(([k, v]) => k !== "__metadata__").map(([k, v]) => [k, {{...v, data_offsets: v.data_offsets.map(x => 8 + metadataLength + x)}}]));
145}};\n""" if not stream_weights else ""
146 return f"""
147const {model_name} = (() => {{
148const getTensorBuffer = (safetensorBuffer, tensorMetadata) => {{
149 return safetensorBuffer.subarray(...tensorMetadata.data_offsets);
150}};
151{getTensorMetadata}
152const createEmptyBuf = (device, size) => {{
153 return device.createBuffer({{size, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST }});
154}};
155
156const createUniformBuf = (device, size) => {{
157 return device.createBuffer({{size, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST}})
158}}
159
160const createInfinityUniformBuf = (device) => {{
161 const size = 4;
162 const buf = device.createBuffer({{
163 mappedAtCreation: true,
164 size,
165 usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST
166 }});
167 new Float32Array(buf.getMappedRange())[0] = Infinity;
168 buf.unmap();
169 return buf;
170}};
171
172const createWeightBuf = (device, size, data) => {{

Callers 1

export_model (Function) · 0.85

Calls 2

dtype_to_js_type (Function) · 0.85
replace (Method) · 0.45

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…