MCPcopy
hub / github.com/tinygrad/tinygrad / export_model

Function export_model

extra/export_model.py:241–302  ·  view source on GitHub ↗
(model, target:str, *inputs, model_name: Optional[str] = "model", stream_weights=False)

Source from the content-addressed store, hash-verified

239"""
240
241def export_model(model, target:str, *inputs, model_name: Optional[str] = "model", stream_weights=False):
242 assert Device.DEFAULT in EXPORT_SUPPORTED_DEVICE, f"only {', '.join(EXPORT_SUPPORTED_DEVICE)} are supported"
243
244 # NOTE: CPU_COUNT=1, since export does not support threading
245 with Context(JIT=2, CPU_COUNT=1): linear, output_bufs = jit_model(model, *inputs)
246 functions, statements, bufs, bufs_to_save = compile_net(linear, output_bufs)
247 state = get_state_dict(model)
248 weight_names = {(id(b), b.offset, b.size, b.dtype): name for name, x in state.items() if (b:=x.uop.base.realized) is not None}
249 input_names = [f"input{i}" for i in range(len(inputs))]
250 output_names = [f"output{i}" for i in range(len(output_bufs))]
251
252 # handle symbolic variables; TODO: refactor to fix some of this stuff upstream in tinygrad
253 symbolic_vars = OrderedDict()
254 for i, (_, args, global_size, _) in enumerate(statements):
255 for j, var in enumerate(args):
256 if getattr(var, "op", None) is Ops.DEFINE_VAR and isinstance(getattr(var, "arg", None), tuple) and isinstance(var.arg[0], str):
257 if var not in symbolic_vars:
258 symbolic_vars[var] = var.arg[0]
259 bufs[symbolic_vars[var]] = (var.dtype.itemsize, var.dtype, symbolic_vars[var])
260 statements[i][1][j] = symbolic_vars[var]
261
262 if global_size:
263 for j, dim in enumerate(global_size):
264 if getattr(dim, "op", None) is Ops.ADD and len(dim.src) == 2 and {dim.src[0].op, dim.src[1].op} == {Ops.DEFINE_VAR, Ops.CONST}:
265 name, val = dim.src if dim.src[1].op is Ops.CONST else reversed(dim.src)
266 global_size[j] = f"_{name.arg[0]}[0] + {val.arg}"
267
268 prg = ""
269 if target == "clang":
270 prg = export_model_clang(functions, statements, bufs, bufs_to_save, input_names, output_names)
271 elif target == "wasm":
272 return export_model_clang(functions, statements, bufs, bufs_to_save, input_names, output_names, weight_names, model_name, symbolic_vars, wasm=True)
273 elif target == "webgpu":
274 prg = export_model_webgpu(functions, statements, bufs, weight_names, input_names, output_names, model_name, symbolic_vars, stream_weights)
275 else:
276 prg = json.dumps({
277 "backend": Device.DEFAULT,
278 "inputs": [{
279 "size": bufs[name][0],
280 "dtype": bufs[name][1].name
281 } for name in input_names],
282 "outputs": [{
283 "size": bufs[name][0],
284 "dtype": bufs[name][1].name
285 } for name in output_names],
286 "functions": functions,
287 "statements": [{
288 "kernel": kernel,
289 "args": args,
290 "global_size": global_size,
291 "local_size": local_size
292 } for (kernel, args, global_size, local_size) in statements],
293 "buffers": {
294 name: {
295 "size": size,
296 "dtype": dtype.name,
297 "id": weight_names[_key] if _key in weight_names else ""
298 } for name, (size,dtype,_key) in bufs.items() if name not in ["input", "outputs"]

Calls 7

Context · Class · 0.90
get_state_dict · Function · 0.90
jit_model · Function · 0.85
compile_net · Function · 0.85
id · Function · 0.85
export_model_clang · Function · 0.85
export_model_webgpu · Function · 0.85

Used in the wild — real call sites across dependent graphs

searching dependent graphs…